From: Stefan Sperling Subject: convert gotd repo_read.c and repo_write.c to single-client To: gameoftrees@openbsd.org Date: Thu, 29 Dec 2022 19:09:56 +0100 Because these processes are now started on demand per client connection there is no need to keep track of multiple clients anymore. Also, these processes can now exit when a disconnect event is received. diff 5e25db14db9eb20ee11b68048b45b3e0f54d50eb 7325ad4c5c711bc79e00d8bfebd5205dc4b79b39 commit - 5e25db14db9eb20ee11b68048b45b3e0f54d50eb commit + 7325ad4c5c711bc79e00d8bfebd5205dc4b79b39 blob - 84ad53e727b231bc0df858b1927be8682b139f1a blob + 70f72ff14ea0d93944e8b9e766316bd5cd5cb1bc --- gotd/repo_read.c +++ gotd/repo_read.c @@ -62,8 +62,7 @@ struct repo_read_client { int *temp_fds; } repo_read; -struct repo_read_client { - STAILQ_ENTRY(repo_read_client) entry; +static struct repo_read_client { uint32_t id; int fd; int delta_cache_fd; @@ -71,46 +70,8 @@ struct repo_read_client { int pack_pipe; struct gotd_object_id_array want_ids; struct gotd_object_id_array have_ids; -}; -STAILQ_HEAD(repo_read_clients, repo_read_client); +} repo_read_client; -static struct repo_read_clients repo_read_clients[GOTD_CLIENT_TABLE_SIZE]; -static SIPHASH_KEY clients_hash_key; - -static uint64_t -client_hash(uint32_t client_id) -{ - return SipHash24(&clients_hash_key, &client_id, sizeof(client_id)); -} - -static void -add_client(struct repo_read_client *client, uint32_t client_id, int fd) -{ - uint64_t slot; - - client->id = client_id; - client->fd = fd; - client->delta_cache_fd = -1; - client->pack_pipe = -1; - slot = client_hash(client->id) % nitems(repo_read_clients); - STAILQ_INSERT_HEAD(&repo_read_clients[slot], client, entry); -} - -static struct repo_read_client * -find_client(uint32_t client_id) -{ - uint64_t slot; - struct repo_read_client *c; - - slot = client_hash(client_id) % nitems(repo_read_clients); - STAILQ_FOREACH(c, &repo_read_clients[slot], entry) { - if (c->id == client_id) - return c; - } - - return NULL; -} - static volatile sig_atomic_t sigint_received; static volatile sig_atomic_t sigterm_received; @@ -292,9 +253,10 @@ list_refs(struct repo_read_client **client, struct ims } static const struct got_error * -list_refs(struct repo_read_client **client, struct imsg *imsg) +list_refs(struct imsg *imsg) { const struct got_error *err; + struct repo_read_client *client = &repo_read_client; struct got_reflist_head refs; struct got_reflist_entry *re; struct gotd_imsg_list_refs_internal ireq; @@ -314,15 +276,17 @@ list_refs(struct repo_read_client **client, struct ims return got_error(GOT_ERR_PRIVSEP_LEN); memcpy(&ireq, imsg->data, sizeof(ireq)); - *client = find_client(ireq.client_id); - if (*client) - return got_error_msg(GOT_ERR_CLIENT_ID, "duplicate client ID"); + if (ireq.client_id == 0) + return got_error(GOT_ERR_CLIENT_ID); + if (client->id != 0) { + return got_error_msg(GOT_ERR_CLIENT_ID, + "duplicate list-refs request"); + } + client->id = ireq.client_id; + client->fd = client_fd; + client->delta_cache_fd = -1; + client->pack_pipe = -1; - *client = calloc(1, sizeof(**client)); - if (*client == NULL) - return got_error_from_errno("calloc"); - add_client(*client, ireq.client_id, client_fd); - imsg_init(&ibuf, client_fd); err = got_ref_list(&refs, repo_read.repo, "", @@ -451,9 +415,10 @@ recv_want(struct repo_read_client **client, struct ims } static const struct got_error * -recv_want(struct repo_read_client **client, struct imsg *imsg) +recv_want(struct imsg *imsg) { const struct got_error *err; + struct repo_read_client *client = &repo_read_client; struct gotd_imsg_want iwant; size_t datalen; char hex[SHA1_DIGEST_STRING_LENGTH]; @@ -473,12 +438,8 @@ recv_want(struct repo_read_client **client, struct ims got_sha1_digest_to_str(id.sha1, hex, sizeof(hex))) log_debug("client wants %s", hex); - *client = find_client(iwant.client_id); - if (*client == NULL) - return got_error(GOT_ERR_CLIENT_ID); + imsg_init(&ibuf, client->fd); - imsg_init(&ibuf, (*client)->fd); - err = got_object_get_type(&obj_type, repo_read.repo, &id); if (err) return err; @@ -487,7 +448,7 @@ recv_want(struct repo_read_client **client, struct ims obj_type != GOT_OBJ_TYPE_TAG) return got_error(GOT_ERR_OBJ_TYPE); - err = record_object_id(&(*client)->want_ids, &id); + err = record_object_id(&client->want_ids, &id); if (err) return err; @@ -497,9 +458,10 @@ recv_have(struct repo_read_client **client, struct ims } static const struct got_error * -recv_have(struct repo_read_client **client, struct imsg *imsg) +recv_have(struct imsg *imsg) { const struct got_error *err; + struct repo_read_client *client = &repo_read_client; struct gotd_imsg_have ihave; size_t datalen; char hex[SHA1_DIGEST_STRING_LENGTH]; @@ -519,12 +481,8 @@ recv_have(struct repo_read_client **client, struct ims got_sha1_digest_to_str(id.sha1, hex, sizeof(hex))) log_debug("client has %s", hex); - *client = find_client(ihave.client_id); - if (*client == NULL) - return got_error(GOT_ERR_CLIENT_ID); + imsg_init(&ibuf, client->fd); - imsg_init(&ibuf, (*client)->fd); - err = got_object_get_type(&obj_type, repo_read.repo, &id); if (err) { if (err->code == GOT_ERR_NO_OBJ) { @@ -542,7 +500,7 @@ recv_have(struct repo_read_client **client, struct ims goto done; } - err = record_object_id(&(*client)->have_ids, &id); + err = record_object_id(&client->have_ids, &id); if (err) return err; @@ -604,9 +562,10 @@ receive_delta_cache_fd(struct repo_read_client **clien } static const struct got_error * -receive_delta_cache_fd(struct repo_read_client **client, struct imsg *imsg, +receive_delta_cache_fd(struct imsg *imsg, struct gotd_imsgev *iev) { + struct repo_read_client *client = &repo_read_client; struct gotd_imsg_send_packfile ireq; size_t datalen; @@ -620,22 +579,18 @@ receive_delta_cache_fd(struct repo_read_client **clien return got_error(GOT_ERR_PRIVSEP_LEN); memcpy(&ireq, imsg->data, sizeof(ireq)); - *client = find_client(ireq.client_id); - if (*client == NULL) - return got_error(GOT_ERR_CLIENT_ID); - - if ((*client)->delta_cache_fd != -1) + if (client->delta_cache_fd != -1) return got_error(GOT_ERR_PRIVSEP_MSG); - (*client)->delta_cache_fd = imsg->fd; - (*client)->report_progress = ireq.report_progress; + client->delta_cache_fd = imsg->fd; + client->report_progress = ireq.report_progress; return NULL; } static const struct got_error * -receive_pack_pipe(struct repo_read_client **client, struct imsg *imsg, - struct gotd_imsgev *iev) +receive_pack_pipe(struct imsg *imsg, struct gotd_imsgev *iev) { + struct repo_read_client *client = &repo_read_client; struct gotd_imsg_packfile_pipe ireq; size_t datalen; @@ -649,21 +604,18 @@ receive_pack_pipe(struct repo_read_client **client, st return got_error(GOT_ERR_PRIVSEP_LEN); memcpy(&ireq, imsg->data, sizeof(ireq)); - *client = find_client(ireq.client_id); - if (*client == NULL) - return got_error(GOT_ERR_CLIENT_ID); - if ((*client)->pack_pipe != -1) + if (client->pack_pipe != -1) return got_error(GOT_ERR_PRIVSEP_MSG); - (*client)->pack_pipe = imsg->fd; + client->pack_pipe = imsg->fd; return NULL; } static const struct got_error * -send_packfile(struct repo_read_client *client, struct imsg *imsg, - struct gotd_imsgev *iev) +send_packfile(struct imsg *imsg, struct gotd_imsgev *iev) { const struct got_error *err = NULL; + struct repo_read_client *client = &repo_read_client; struct gotd_imsg_packfile_done idone; uint8_t packsha1[SHA1_DIGEST_LENGTH]; char hex[SHA1_DIGEST_STRING_LENGTH]; @@ -722,9 +674,8 @@ recv_disconnect(struct imsg *imsg) const struct got_error *err = NULL; struct gotd_imsg_disconnect idisconnect; size_t datalen; - int client_fd, delta_cache_fd, pack_pipe; - struct repo_read_client *client = NULL; - uint64_t slot; + int delta_cache_fd, pack_pipe; + struct repo_read_client *client = &repo_read_client; datalen = imsg->hdr.len - IMSG_HEADER_SIZE; if (datalen != sizeof(idisconnect)) @@ -733,23 +684,14 @@ recv_disconnect(struct imsg *imsg) log_debug("client disconnecting"); - client = find_client(idisconnect.client_id); - if (client == NULL) - return got_error(GOT_ERR_CLIENT_ID); - - slot = client_hash(client->id) % nitems(repo_read_clients); - STAILQ_REMOVE(&repo_read_clients[slot], client, repo_read_client, - entry); free_object_ids(&client->have_ids); free_object_ids(&client->want_ids); - client_fd = client->fd; - delta_cache_fd = client->delta_cache_fd; - pack_pipe = client->pack_pipe; - free(client); - if (close(client_fd) == -1) + if (close(client->fd) == -1) err = got_error_from_errno("close"); + delta_cache_fd = client->delta_cache_fd; if (delta_cache_fd != -1 && close(delta_cache_fd) == -1 && err == NULL) return got_error_from_errno("close"); + pack_pipe = client->pack_pipe; if (pack_pipe != -1 && close(pack_pipe) == -1 && err == NULL) return got_error_from_errno("close"); return err; @@ -764,7 +706,7 @@ repo_read_dispatch(int fd, short event, void *arg) struct imsg imsg; ssize_t n; int shut = 0; - struct repo_read_client *client = NULL; + struct repo_read_client *client = &repo_read_client; if (event & EV_READ) { if ((n = imsg_read(ibuf)) == -1 && errno != EAGAIN) @@ -782,47 +724,50 @@ repo_read_dispatch(int fd, short event, void *arg) } while (err == NULL && check_cancelled(NULL) == NULL) { - client = NULL; if ((n = imsg_get(ibuf, &imsg)) == -1) fatal("%s: imsg_get", __func__); if (n == 0) /* No more messages. */ break; + if (imsg.hdr.type != GOTD_IMSG_LIST_REFS_INTERNAL && + client->id == 0) { + err = got_error(GOT_ERR_PRIVSEP_MSG); + break; + } + switch (imsg.hdr.type) { case GOTD_IMSG_LIST_REFS_INTERNAL: - err = list_refs(&client, &imsg); + err = list_refs(&imsg); if (err) log_warnx("%s: ls-refs: %s", repo_read.title, err->msg); break; case GOTD_IMSG_WANT: - err = recv_want(&client, &imsg); + err = recv_want(&imsg); if (err) log_warnx("%s: want-line: %s", repo_read.title, err->msg); break; case GOTD_IMSG_HAVE: - err = recv_have(&client, &imsg); + err = recv_have(&imsg); if (err) log_warnx("%s: have-line: %s", repo_read.title, err->msg); break; case GOTD_IMSG_SEND_PACKFILE: - err = receive_delta_cache_fd(&client, &imsg, iev); + err = receive_delta_cache_fd(&imsg, iev); if (err) log_warnx("%s: receiving delta cache: %s", repo_read.title, err->msg); break; case GOTD_IMSG_PACKFILE_PIPE: - err = receive_pack_pipe(&client, &imsg, iev); + err = receive_pack_pipe(&imsg, iev); if (err) { log_warnx("%s: receiving pack pipe: %s", repo_read.title, err->msg); break; } - if (client->pack_pipe == -1) - break; - err = send_packfile(client, &imsg, iev); + err = send_packfile(&imsg, iev); if (err) log_warnx("%s: sending packfile: %s", repo_read.title, err->msg); @@ -832,6 +777,7 @@ repo_read_dispatch(int fd, short event, void *arg) if (err) log_warnx("%s: disconnect: %s", repo_read.title, err->msg); + shut = 1; break; default: log_debug("%s: unexpected imsg %d", repo_read.title, @@ -845,7 +791,7 @@ repo_read_dispatch(int fd, short event, void *arg) if (!shut && check_cancelled(NULL) == NULL) { if (err && gotd_imsg_send_error_event(iev, PROC_REPO_READ, - client ? client->id : 0, err) == -1) { + client->id, err) == -1) { log_warnx("could not send error to parent: %s", err->msg); } @@ -869,8 +815,6 @@ repo_read_main(const char *title, const char *repo_pat repo_read.pack_fds = pack_fds; repo_read.temp_fds = temp_fds; - arc4random_buf(&clients_hash_key, sizeof(clients_hash_key)); - err = got_repo_open(&repo_read.repo, repo_path, NULL, pack_fds); if (err) goto done; blob - a8b8a48b519156b15cdfed829ea3090d0552fc1a blob + 965572537276de7c92606cb33b97d87e09190553 --- gotd/repo_write.c +++ gotd/repo_write.c @@ -78,8 +78,7 @@ struct repo_write_client { }; STAILQ_HEAD(gotd_ref_updates, gotd_ref_update); -struct repo_write_client { - STAILQ_ENTRY(repo_write_client) entry; +static struct repo_write_client { uint32_t id; int fd; int pack_pipe; @@ -88,48 +87,8 @@ struct repo_write_client { int packidx_fd; struct gotd_ref_updates ref_updates; int nref_updates; -}; -STAILQ_HEAD(repo_write_clients, repo_write_client); +} repo_write_client; -static struct repo_write_clients repo_write_clients[GOTD_CLIENT_TABLE_SIZE]; -static SIPHASH_KEY clients_hash_key; - -static uint64_t -client_hash(uint32_t client_id) -{ - return SipHash24(&clients_hash_key, &client_id, sizeof(client_id)); -} - -static void -add_client(struct repo_write_client *client, uint32_t client_id, int fd) -{ - uint64_t slot; - - client->id = client_id; - client->fd = fd; - client->pack_pipe = -1; - client->packidx_fd = -1; - STAILQ_INIT(&client->ref_updates); - client->nref_updates = 0; - slot = client_hash(client->id) % nitems(repo_write_clients); - STAILQ_INSERT_HEAD(&repo_write_clients[slot], client, entry); -} - -static struct repo_write_client * -find_client(uint32_t client_id) -{ - uint64_t slot; - struct repo_write_client *c; - - slot = client_hash(client_id) % nitems(repo_write_clients); - STAILQ_FOREACH(c, &repo_write_clients[slot], entry) { - if (c->id == client_id) - return c; - } - - return NULL; -} - static volatile sig_atomic_t sigint_received; static volatile sig_atomic_t sigterm_received; @@ -263,9 +222,10 @@ list_refs(struct repo_write_client **client, struct im } static const struct got_error * -list_refs(struct repo_write_client **client, struct imsg *imsg) +list_refs(struct imsg *imsg) { const struct got_error *err; + struct repo_write_client *client = &repo_write_client; struct got_reflist_head refs; struct got_reflist_entry *re; struct gotd_imsg_list_refs_internal ireq; @@ -284,15 +244,18 @@ list_refs(struct repo_write_client **client, struct im return got_error(GOT_ERR_PRIVSEP_LEN); memcpy(&ireq, imsg->data, sizeof(ireq)); - *client = find_client(ireq.client_id); - if (*client) - return got_error_msg(GOT_ERR_CLIENT_ID, "duplicate client ID"); + if (ireq.client_id == 0) + return got_error(GOT_ERR_CLIENT_ID); + if (client->id != 0) { + return got_error_msg(GOT_ERR_CLIENT_ID, + "duplicate list-refs request"); + } + client->id = ireq.client_id; + client->fd = client_fd; + client->pack_pipe = -1; + client->packidx_fd = -1; + client->nref_updates = 0; - *client = calloc(1, sizeof(**client)); - if (*client == NULL) - return got_error_from_errno("calloc"); - add_client(*client, ireq.client_id, client_fd); - imsg_init(&ibuf, client_fd); err = got_ref_list(&refs, repo_write.repo, "", @@ -361,9 +324,10 @@ recv_ref_update(struct repo_write_client **client, str } static const struct got_error * -recv_ref_update(struct repo_write_client **client, struct imsg *imsg) +recv_ref_update(struct imsg *imsg) { const struct got_error *err = NULL; + struct repo_write_client *client = &repo_write_client; struct gotd_imsg_ref_update iref; size_t datalen; char *refname = NULL; @@ -381,12 +345,8 @@ recv_ref_update(struct repo_write_client **client, str if (datalen != sizeof(iref) + iref.name_len) return got_error(GOT_ERR_PRIVSEP_LEN); - *client = find_client(iref.client_id); - if (*client == NULL) - return got_error(GOT_ERR_CLIENT_ID); + imsg_init(&ibuf, client->fd); - imsg_init(&ibuf, (*client)->fd); - refname = malloc(iref.name_len + 1); if (refname == NULL) return got_error_from_errno("malloc"); @@ -455,8 +415,8 @@ recv_ref_update(struct repo_write_client **client, str repo_write.pid); ref_update->ref = ref; - STAILQ_INSERT_HEAD(&(*client)->ref_updates, ref_update, entry); - (*client)->nref_updates++; + STAILQ_INSERT_HEAD(&client->ref_updates, ref_update, entry); + client->nref_updates++; ref = NULL; ref_update = NULL; done: @@ -853,10 +813,10 @@ report_pack_status(struct repo_write_client *client, } static const struct got_error * -report_pack_status(struct repo_write_client *client, - const struct got_error *unpack_err) +report_pack_status(const struct got_error *unpack_err) { const struct got_error *err = NULL; + struct repo_write_client *client = &repo_write_client; struct gotd_imsg_packfile_status istatus; struct ibuf *wbuf; struct imsgbuf ibuf; @@ -899,9 +859,10 @@ recv_packfile(struct repo_write_client **client, struc } static const struct got_error * -recv_packfile(struct repo_write_client **client, struct imsg *imsg) +recv_packfile(struct imsg *imsg) { const struct got_error *err = NULL, *unpack_err; + struct repo_write_client *client = &repo_write_client; struct gotd_imsg_recv_packfile ireq; FILE *tempfiles[3] = { NULL, NULL, NULL }; struct repo_tempfile { @@ -924,20 +885,15 @@ recv_packfile(struct repo_write_client **client, struc return got_error(GOT_ERR_PRIVSEP_LEN); memcpy(&ireq, imsg->data, sizeof(ireq)); - *client = find_client(ireq.client_id); - if (*client == NULL || STAILQ_EMPTY(&(*client)->ref_updates)) - return got_error(GOT_ERR_CLIENT_ID); - - if ((*client)->pack_pipe == -1 || - (*client)->packidx_fd == -1) + if (client->pack_pipe == -1 || client->packidx_fd == -1) return got_error(GOT_ERR_PRIVSEP_NO_FD); - imsg_init(&ibuf, (*client)->fd); + imsg_init(&ibuf, client->fd); if (imsg->fd == -1) return got_error(GOT_ERR_PRIVSEP_NO_FD); - pack = &(*client)->pack; + pack = &client->pack; memset(pack, 0, sizeof(*pack)); pack->fd = imsg->fd; err = got_delta_cache_alloc(&pack->delta_cache); @@ -972,10 +928,10 @@ recv_packfile(struct repo_write_client **client, struc goto done; log_debug("receiving pack data"); - unpack_err = recv_packdata(&pack_filesize, (*client)->pack_sha1, - (*client)->pack_pipe, pack->fd); + unpack_err = recv_packdata(&pack_filesize, client->pack_sha1, + client->pack_pipe, pack->fd); if (ireq.report_status) { - err = report_pack_status(*client, unpack_err); + err = report_pack_status(unpack_err); if (err) { /* Git clients hang up after sending the pack file. */ if (err->code == GOT_ERR_EOF) @@ -993,23 +949,23 @@ recv_packfile(struct repo_write_client **client, struc log_debug("begin indexing pack (%lld bytes in size)", (long long)pack->filesize); - err = got_pack_index(pack, (*client)->packidx_fd, - tempfiles[0], tempfiles[1], tempfiles[2], (*client)->pack_sha1, + err = got_pack_index(pack, client->packidx_fd, + tempfiles[0], tempfiles[1], tempfiles[2], client->pack_sha1, pack_index_progress, NULL, &rl); if (err) goto done; log_debug("done indexing pack"); - if (fsync((*client)->packidx_fd) == -1) { + if (fsync(client->packidx_fd) == -1) { err = got_error_from_errno("fsync"); goto done; } - if (lseek((*client)->packidx_fd, 0L, SEEK_SET) == -1) + if (lseek(client->packidx_fd, 0L, SEEK_SET) == -1) err = got_error_from_errno("lseek"); done: - if (close((*client)->pack_pipe) == -1 && err == NULL) + if (close(client->pack_pipe) == -1 && err == NULL) err = got_error_from_errno("close"); - (*client)->pack_pipe = -1; + client->pack_pipe = -1; for (i = 0; i < nitems(repo_tempfiles); i++) { struct repo_tempfile *t = &repo_tempfiles[i]; if (t->idx != -1) @@ -1026,9 +982,10 @@ verify_packfile(struct repo_write_client *client) } static const struct got_error * -verify_packfile(struct repo_write_client *client) +verify_packfile(void) { const struct got_error *err = NULL, *close_err; + struct repo_write_client *client = &repo_write_client; struct gotd_ref_update *ref_update; struct got_packidx *packidx = NULL; struct stat sb; @@ -1085,8 +1042,9 @@ install_packfile(struct repo_write_client *client, str } static const struct got_error * -install_packfile(struct repo_write_client *client, struct gotd_imsgev *iev) +install_packfile(struct gotd_imsgev *iev) { + struct repo_write_client *client = &repo_write_client; struct gotd_imsg_packfile_install inst; int ret; @@ -1103,9 +1061,9 @@ send_ref_updates_start(struct repo_write_client *clien } static const struct got_error * -send_ref_updates_start(struct repo_write_client *client, int nref_updates, - struct gotd_imsgev *iev) +send_ref_updates_start(int nref_updates, struct gotd_imsgev *iev) { + struct repo_write_client *client = &repo_write_client; struct gotd_imsg_ref_updates_start istart; int ret; @@ -1123,9 +1081,9 @@ send_ref_update(struct repo_write_client *client, static const struct got_error * -send_ref_update(struct repo_write_client *client, - struct gotd_ref_update *ref_update, struct gotd_imsgev *iev) +send_ref_update(struct gotd_ref_update *ref_update, struct gotd_imsgev *iev) { + struct repo_write_client *client = &repo_write_client; struct gotd_imsg_ref_update iref; const char *refname = got_ref_get_name(ref_update->ref); struct ibuf *wbuf; @@ -1157,17 +1115,18 @@ update_refs(struct repo_write_client *client, struct g } static const struct got_error * -update_refs(struct repo_write_client *client, struct gotd_imsgev *iev) +update_refs(struct gotd_imsgev *iev) { const struct got_error *err = NULL; + struct repo_write_client *client = &repo_write_client; struct gotd_ref_update *ref_update; - err = send_ref_updates_start(client, client->nref_updates, iev); + err = send_ref_updates_start(client->nref_updates, iev); if (err) return err; STAILQ_FOREACH(ref_update, &client->ref_updates, entry) { - err = send_ref_update(client, ref_update, iev); + err = send_ref_update(ref_update, iev); if (err) goto done; } @@ -1181,9 +1140,8 @@ recv_disconnect(struct imsg *imsg) const struct got_error *err = NULL; struct gotd_imsg_disconnect idisconnect; size_t datalen; - int client_fd = -1, pack_pipe = -1, idxfd = -1; - struct repo_write_client *client = NULL; - uint64_t slot; + int pack_pipe = -1, idxfd = -1; + struct repo_write_client *client = &repo_write_client; datalen = imsg->hdr.len - IMSG_HEADER_SIZE; if (datalen != sizeof(idisconnect)) @@ -1192,13 +1150,6 @@ recv_disconnect(struct imsg *imsg) log_debug("client disconnecting"); - client = find_client(idisconnect.client_id); - if (client == NULL) - return got_error(GOT_ERR_CLIENT_ID); - - slot = client_hash(client->id) % nitems(repo_write_clients); - STAILQ_REMOVE(&repo_write_clients[slot], client, repo_write_client, - entry); while (!STAILQ_EMPTY(&client->ref_updates)) { struct gotd_ref_update *ref_update; ref_update = STAILQ_FIRST(&client->ref_updates); @@ -1207,23 +1158,21 @@ recv_disconnect(struct imsg *imsg) free(ref_update); } err = got_pack_close(&client->pack); - client_fd = client->fd; + if (client->fd != -1 && close(client->fd) == -1) + err = got_error_from_errno("close"); pack_pipe = client->pack_pipe; - idxfd = client->packidx_fd; - free(client); - if (client_fd != -1 && close(client_fd) == -1) - err = got_error_from_errno("close"); if (pack_pipe != -1 && close(pack_pipe) == -1 && err == NULL) err = got_error_from_errno("close"); + idxfd = client->packidx_fd; if (idxfd != -1 && close(idxfd) == -1 && err == NULL) err = got_error_from_errno("close"); return err; } static const struct got_error * -receive_pack_pipe(struct repo_write_client **client, struct imsg *imsg, - struct gotd_imsgev *iev) +receive_pack_pipe(struct imsg *imsg, struct gotd_imsgev *iev) { + struct repo_write_client *client = &repo_write_client; struct gotd_imsg_packfile_pipe ireq; size_t datalen; @@ -1237,20 +1186,17 @@ receive_pack_pipe(struct repo_write_client **client, s return got_error(GOT_ERR_PRIVSEP_LEN); memcpy(&ireq, imsg->data, sizeof(ireq)); - *client = find_client(ireq.client_id); - if (*client == NULL) - return got_error(GOT_ERR_CLIENT_ID); - if ((*client)->pack_pipe != -1) + if (client->pack_pipe != -1) return got_error(GOT_ERR_PRIVSEP_MSG); - (*client)->pack_pipe = imsg->fd; + client->pack_pipe = imsg->fd; return NULL; } static const struct got_error * -receive_pack_idx(struct repo_write_client **client, struct imsg *imsg, - struct gotd_imsgev *iev) +receive_pack_idx(struct imsg *imsg, struct gotd_imsgev *iev) { + struct repo_write_client *client = &repo_write_client; struct gotd_imsg_packidx_file ireq; size_t datalen; @@ -1264,13 +1210,10 @@ receive_pack_idx(struct repo_write_client **client, st return got_error(GOT_ERR_PRIVSEP_LEN); memcpy(&ireq, imsg->data, sizeof(ireq)); - *client = find_client(ireq.client_id); - if (*client == NULL) - return got_error(GOT_ERR_CLIENT_ID); - if ((*client)->packidx_fd != -1) + if (client->packidx_fd != -1) return got_error(GOT_ERR_PRIVSEP_MSG); - (*client)->packidx_fd = imsg->fd; + client->packidx_fd = imsg->fd; return NULL; } @@ -1281,7 +1224,7 @@ repo_write_dispatch(int fd, short event, void *arg) struct gotd_imsgev *iev = arg; struct imsgbuf *ibuf = &iev->ibuf; struct imsg imsg; - struct repo_write_client *client = NULL; + struct repo_write_client *client = &repo_write_client; ssize_t n; int shut = 0; @@ -1306,21 +1249,27 @@ repo_write_dispatch(int fd, short event, void *arg) if (n == 0) /* No more messages. */ break; + if (imsg.hdr.type != GOTD_IMSG_LIST_REFS_INTERNAL && + client->id == 0) { + err = got_error(GOT_ERR_PRIVSEP_MSG); + break; + } + switch (imsg.hdr.type) { case GOTD_IMSG_LIST_REFS_INTERNAL: - err = list_refs(&client, &imsg); + err = list_refs(&imsg); if (err) log_warnx("%s: ls-refs: %s", repo_write.title, err->msg); break; case GOTD_IMSG_REF_UPDATE: - err = recv_ref_update(&client, &imsg); + err = recv_ref_update(&imsg); if (err) log_warnx("%s: ref-update: %s", repo_write.title, err->msg); break; case GOTD_IMSG_PACKFILE_PIPE: - err = receive_pack_pipe(&client, &imsg, iev); + err = receive_pack_pipe(&imsg, iev); if (err) { log_warnx("%s: receiving pack pipe: %s", repo_write.title, err->msg); @@ -1328,7 +1277,7 @@ repo_write_dispatch(int fd, short event, void *arg) } break; case GOTD_IMSG_PACKIDX_FILE: - err = receive_pack_idx(&client, &imsg, iev); + err = receive_pack_idx(&imsg, iev); if (err) { log_warnx("%s: receiving pack index: %s", repo_write.title, err->msg); @@ -1336,25 +1285,25 @@ repo_write_dispatch(int fd, short event, void *arg) } break; case GOTD_IMSG_RECV_PACKFILE: - err = recv_packfile(&client, &imsg); + err = recv_packfile(&imsg); if (err) { log_warnx("%s: receive packfile: %s", repo_write.title, err->msg); break; } - err = verify_packfile(client); + err = verify_packfile(); if (err) { log_warnx("%s: verify packfile: %s", repo_write.title, err->msg); break; } - err = install_packfile(client, iev); + err = install_packfile(iev); if (err) { log_warnx("%s: install packfile: %s", repo_write.title, err->msg); break; } - err = update_refs(client, iev); + err = update_refs(iev); if (err) { log_warnx("%s: update refs: %s", repo_write.title, err->msg); @@ -1365,6 +1314,7 @@ repo_write_dispatch(int fd, short event, void *arg) if (err) log_warnx("%s: disconnect: %s", repo_write.title, err->msg); + shut = 1; break; default: log_debug("%s: unexpected imsg %d", repo_write.title, @@ -1378,7 +1328,7 @@ repo_write_dispatch(int fd, short event, void *arg) if (!shut && check_cancelled(NULL) == NULL) { if (err && gotd_imsg_send_error_event(iev, PROC_REPO_WRITE, - client ? client->id : 0, err) == -1) { + client->id, err) == -1) { log_warnx("could not send error to parent: %s", err->msg); } @@ -1402,7 +1352,7 @@ repo_write_main(const char *title, const char *repo_pa repo_write.pack_fds = pack_fds; repo_write.temp_fds = temp_fds; - arc4random_buf(&clients_hash_key, sizeof(clients_hash_key)); + STAILQ_INIT(&repo_write_client.ref_updates); err = got_repo_open(&repo_write.repo, repo_path, NULL, pack_fds); if (err)