asds_asds
asds_asds

Reputation: 1062

How to exit a blocking call of recv() on a thread from a different thread?

I have a code which runs two threads,

The first thread waits on the sender for data using recv() and then forwards the data to the receiver using send.

The second thread waits on the receiver for data using recv() and then forwards the data to the sender using send.

It is important that both of these work parallelly.

Suppose the sender disconnects, the first thread detects this and closes the connection.

How do I tell the second thread which is still waiting on the receiver that the connection has been closed and no further communication is required?

recieve_packet has been implemented using recv().

send_packet has been implemented using send().

sender_fd is the socket file descriptor of the sender.

reciever_fd is the socket file descriptor of the receiver.

void* sender_to_reciever(){
    int t1;packet* p = malloc(sizeof(packet));
    while((t1=recieve_packet(sender_fd,&p))!=0){
        send_packet(reciever_fd,p);
    }
    close(sender_fd);
}


void* reciever_to_sender(){
    int t1;packet* p = malloc(sizeof(packet));
    while((t1=recieve_packet(reciever_fd,&p))!=0){
        send_packet(sender_fd,p);
    }
    close(reciever_fd);
}

I don't want to change the implementation of the send_packet and recieve_packet function calls.

I tried closing both sender_fd and reciever_d if either while loop exits. It did not work, however.

Code for channel.c which handles both the sender and the reciever :-

#include "packets.c"


#define SERVER_PORT "8642"
#define QUEUE_LENGTH 10

void handle_connection(int);
void* sender_to_receiver();
void* receiver_to_sender();
int open_outgoing_connection(char*, char*);
int sender_fd;
int receiver_fd;

int main(int argc, char* argv[]){

    int sock_fd,new_fd,rv,yes;yes=1;
    struct addrinfo hints,*res;
    struct sockaddr_storage client_addr;
    socklen_t addr_size;
    char client_details[INET6_ADDRSTRLEN];


    struct sigaction sa;

    memset(&hints, 0, sizeof hints);
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;

    if((rv=getaddrinfo(NULL, SERVER_PORT, &hints, &res))!=0){
        printf("Error getaddrinfo : %s\n",gai_strerror(rv));
        return 1;
    }


    if((sock_fd = socket(res->ai_family, res->ai_socktype, res->ai_protocol))==-1){
        printf("Error socket file descriptor\n");
        return 1;
    }

    if(setsockopt(sock_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(int)) == -1){
        printf("Error setsockopt\n");
        return 1;
    }

    if(bind(sock_fd,res->ai_addr,res->ai_addrlen)==-1){
        close(sock_fd);
        printf("Error bind\n");
        return 1;
    }

    if(listen(sock_fd, QUEUE_LENGTH)==-1){
        printf("Error listen\n");
        return 1;
    }

    freeaddrinfo(res);

    sa.sa_handler = sigchld_handler;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = SA_RESTART;
    if(sigaction(SIGCHLD,&sa,NULL) == -1){
        printf("Error sigaction\n");
        return 1;
    }

    // Now we have prepared the socket(ip+port) for accepting incoming connections.
    printf("Server PID : %d\n",getpid());
    while(1){
        addr_size = sizeof client_addr;
        new_fd = accept(sock_fd,(struct sockaddr*)&client_addr, &addr_size);

        if(new_fd==-1){
            printf("Error Accepting Request %d\n",getpid());
            return 1;
        }

        inet_ntop(client_addr.ss_family,get_in_addr((struct sockaddr*)&client_addr),client_details,sizeof client_details);

        if(!fork()){ // Child Process
            close(sock_fd);
            printf("Connection Accepted From : %s by PID:%d\n",client_details,getpid());
            handle_connection(new_fd);
            exit(0);
        }
    }
    return 1;
}


void handle_connection(int socket_sender){
    packet* p = malloc(sizeof(packet));
    int t1;
    if((t1=receive_packet(socket_sender, &p))==0){
            printf("Closed Connection\n");
    }else{
        int socket_receiver;
        if((socket_receiver=open_outgoing_connection(p->destination_ip,p->destination_port))!=-1){      
            sender_fd = socket_sender;
            receiver_fd = socket_receiver;

            pthread_t str,rts;
            str     = pthread_self();
            rts   = pthread_self();

            pthread_create(&str,NULL,sender_to_receiver,NULL);
            pthread_create(&rts,NULL,receiver_to_sender,NULL);

            pthread_join(str,NULL);
            pthread_join(rts,NULL);

        }else{
            printf("Error Connecting to receiver\n");
        }
    }
}

void* sender_to_receiver(){
    int t1;packet* p = malloc(sizeof(packet));
    while((t1=receive_packet(sender_fd,&p))!=0){
        printf("SENDER\n");
        display_packet(p);
        send_packet(receiver_fd,p);
    }
    printf("Sender Disconnected\n");
    close(sender_fd);close(receiver_fd);
}
void* receiver_to_sender(){
    int t1;packet* p = malloc(sizeof(packet));
    while((t1=receive_packet(receiver_fd,&p))!=0){
        printf("Receiver\n");
        display_packet(p);
        send_packet(sender_fd,p);
    }
    printf("Receiver Disconnected\n");
    close(receiver_fd);close(sender_fd);
}

int open_outgoing_connection(char* ip, char* port){
    int gai;
    char server_ip[100];memset(server_ip,'\0',sizeof(server_ip));
    struct addrinfo hints,*server;
    memset(&hints,0,sizeof hints);
    hints.ai_family     = AF_UNSPEC;
    hints.ai_socktype   = SOCK_STREAM;
    int socket_fd;
    if((gai=getaddrinfo(ip,port,&hints,&server)) != 0){
        printf("GetAddrInfo Error: %s\n",gai_strerror(gai));
        return -1;
    }

    if((socket_fd = socket(server->ai_family, server->ai_socktype, server->ai_protocol)) == -1){
        printf("Socket Error\n");
        return -1;
    }

    if(connect(socket_fd,server->ai_addr,server->ai_addrlen) == -1){
        printf("Connect Error\n");
        return -1;
    }
    freeaddrinfo(server);
    inet_ntop(server->ai_family, get_in_addr((struct sockaddr*)server->ai_addr), server_ip, sizeof(server_ip));
    printf("Connected to: %s\n",server_ip);
    return socket_fd;
}

Code for reciver:-

#include "packets.c"

#define QUEUE_LENGTH 10

void handle_connection(int);
int main(int argc, char* argv[]){

    if(argc!=2){
        printf("Enter PORT\n");
        return 1;
    }

    int sock_fd,new_fd,rv,yes;yes=1;
    struct addrinfo hints,*res;
    struct sockaddr_storage client_addr;
    socklen_t addr_size;
    char client_details[INET6_ADDRSTRLEN];


    struct sigaction sa;

    memset(&hints, 0, sizeof hints);
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE;

    if((rv=getaddrinfo(NULL, argv[1], &hints, &res))!=0){
        printf("Error getaddrinfo : %s\n",gai_strerror(rv));
        return 1;
    }


    if((sock_fd = socket(res->ai_family, res->ai_socktype, res->ai_protocol))==-1){
        printf("Error socket file descriptor\n");
        return 1;
    }

    if(setsockopt(sock_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(int)) == -1){
        printf("Error setsockopt\n");
        return 1;
    }

    if(bind(sock_fd,res->ai_addr,res->ai_addrlen)==-1){
        close(sock_fd);
        printf("Error bind\n");
        return 1;
    }

    if(listen(sock_fd, QUEUE_LENGTH)==-1){
        printf("Error listen\n");
        return 1;
    }

    freeaddrinfo(res);

    sa.sa_handler = sigchld_handler;
    sigemptyset(&sa.sa_mask);
    sa.sa_flags = SA_RESTART;
    if(sigaction(SIGCHLD,&sa,NULL) == -1){
        printf("Error sigaction\n");
        return 1;
    }

    // Now we have prepared the socket(ip+port) for accepting incoming connections.
    printf("Server PID : %d\n",getpid());
    while(1){
        addr_size = sizeof client_addr;
        new_fd = accept(sock_fd,(struct sockaddr*)&client_addr, &addr_size);

        if(new_fd==-1){
            printf("Error Accepting Request %d\n",getpid());
            return 1;
        }

        inet_ntop(client_addr.ss_family,get_in_addr((struct sockaddr*)&client_addr),client_details,sizeof client_details);

        if(!fork()){ // Child Process
            close(sock_fd);
            printf("Connection Accepted From : %s by PID:%d\n",client_details,getpid());
            handle_connection(new_fd);
            exit(0);
        }
        close(new_fd);
    }
    return 1;
}



void handle_connection(int socket_fd){
    int t1;packet* p = malloc(sizeof(packet));
    while((t1=receive_packet(socket_fd,&p))!=0){
        display_packet(p);
        p->message = "ACK";
        p->timestamp = get_time_in_ns();
        send_packet(socket_fd,p);
    }
    printf("Sender Disconnected\n");
}

Code for Sender:-

#include "packets.c"

#define CHANNEL_PORT "8642"
#define CHANNEL_IP   "127.0.0.1"

int socket_fd;
char* destination_ip;
char* destination_port;
int number_of_packets;
char message[MESSAGE_BUFFER_LEN];


void prepare_packet_header(packet *p){
    p->destination_ip=destination_ip;
    p->destination_port=destination_port;
    p->timestamp = get_time_in_ns();
    p->length=0;
}


void divide_message_and_send_packets(){
    if(number_of_packets>strlen(message)){
        number_of_packets = strlen(message);
    }
    int indi_len = strlen(message)/number_of_packets;
    int lm = 0;int i,j;
    packet* all_packets[number_of_packets];

    for(i=0;i<number_of_packets;i+=1)all_packets[i]=malloc(sizeof(packet));
    // HANDSHAKE
    prepare_packet_header(all_packets[0]);
    all_packets[0]->message="SYN";
    all_packets[0]->uid=-1;
    send_packet(socket_fd,all_packets[0]);
    // HANDSHAKE OVER
    for(i=0;i<number_of_packets;i+=1){
        // printf("Processing Packet: %d with message[%d:%d]\n",i,lm,lm+indi_len);
        if(i!=number_of_packets-1){
            char temp[indi_len+1];memset(temp,'\0', sizeof temp);
            for(j=lm;j<lm+indi_len;j+=1)temp[j-lm]=message[j];
            all_packets[i]->message = malloc(sizeof temp);
            strcpy(all_packets[i]->message, temp);
            all_packets[i]->uid = i;
        }else{
            char temp[strlen(message)-lm+1];memset(temp,'\0', sizeof temp);
            for(j=lm;j<strlen(message);j+=1)temp[j-lm]=message[j];
            all_packets[i]->message = malloc(sizeof temp);
            strcpy(all_packets[i]->message, temp);
            all_packets[i]->uid = i;
        }
        prepare_packet_header(all_packets[i]);lm+=indi_len;
        display_packet(all_packets[i]);
        send_packet(socket_fd,all_packets[i]);
    }

}

void main(int argc,char* argv[]){
    if(argc!=5){
        printf("Enter DESTINATION_IP DESTINATION_PORT NUMBER_OF_PACKETS MESSAGE\n");
        return;
    }

    strcpy(message,argv[4]);
    number_of_packets=atoi(argv[3]);
    destination_ip=malloc(sizeof argv[1] + 1);memset(destination_ip,'\0', sizeof destination_ip);
    destination_port=malloc(sizeof argv[2] + 1);memset(destination_port,'\0', sizeof destination_port);
    strcpy(destination_ip,argv[1]);
    strcpy(destination_port,argv[2]);
    int gai;
    char server_ip[100];memset(server_ip,'\0',sizeof(server_ip));
    struct addrinfo hints,*server;
    memset(&hints,0,sizeof hints);
    hints.ai_family     = AF_UNSPEC;
    hints.ai_socktype   = SOCK_STREAM;

    if((gai=getaddrinfo(CHANNEL_IP,CHANNEL_PORT,&hints,&server)) != 0){
        printf("GetAddrInfo Error: %s\n",gai_strerror(gai));
        return;
    }

    if((socket_fd = socket(server->ai_family, server->ai_socktype, server->ai_protocol)) == -1){
        printf("Socket Error\n");
    }

    if(connect(socket_fd,server->ai_addr,server->ai_addrlen) == -1){
        printf("Connect Error\n");
    }
    freeaddrinfo(server);
    inet_ntop(server->ai_family, get_in_addr((struct sockaddr*)server->ai_addr), server_ip, sizeof(server_ip));
    printf("Connected to: %s\n",server_ip);

    divide_message_and_send_packets();
    while(1){} // busy wait taht simulates future work
}

Code for packets.c (Contains functions related to the packets):-

#include "helper.c"

typedef struct StupidAssignment{
    long length;
    char* destination_ip;
    char* destination_port;
    long timestamp;
    long uid;
    char* message;
}packet;

int receive_packet(int socket,packet** p1){
    packet* p = *p1;
    int remaining=0;int i;
    int received=0;
    long content_length=0;
    remaining=11;
    char buffer[MESSAGE_BUFFER_LEN];memset(buffer,'\0',sizeof(buffer));
    while(remaining>0){
        int t1 = recv(socket, buffer+received, remaining, 0);
        if(t1==0)return 0;
        remaining-=t1;
        received+=t1;
    }
    content_length = read_long(buffer, received);

    received=0;
    remaining=content_length;p->length=content_length;
    memset(buffer,'\0',sizeof(buffer));
    while(remaining>0){
        int t1 = recv(socket, buffer+received, remaining, 0);
        if(t1==0)return 0;
        remaining-=t1;
        received+=t1;
    }

    char part[MESSAGE_BUFFER_LEN];memset(part,'\0',sizeof(part));int part_len=0;int nlmkr=0;
    for(i=0;i<=content_length;i+=1){
        if(buffer[i]=='\n' || i==content_length){
            nlmkr+=1;
            if(nlmkr==1)    read_char(&(p->destination_ip), part, part_len);
            else if(nlmkr==2)   read_char(&(p->destination_port), part, part_len);
            else if(nlmkr==3)   p->timestamp = read_long(part, part_len);
            else if(nlmkr==4)   p->uid = read_long(part, part_len);
            else if(nlmkr==6)   read_char(&(p->message), part, part_len);
            part_len=0;memset(part, '\0', sizeof part);

        }else{
            part[part_len++]=buffer[i];
        }
    }
    return 1;
}

void send_packet(int socket,packet *p){
    char temp[MESSAGE_BUFFER_LEN];memset(temp,'\0',sizeof temp);
    strcat(temp,p->destination_ip);strcat(temp,"\n");
    strcat(temp,p->destination_port);strcat(temp,"\n");
    snprintf(temp+strlen(temp),100,"%ld\n",p->timestamp);
    snprintf(temp+strlen(temp),100,"%ld\n",p->uid);
    // write_long(p->timestamp,temp);strcat(temp,"\n");
    // write_long(p->uid,temp);strcat(temp,"\n");
    strcat(temp,"\n");
    strcat(temp,p->message);
    char buffer[MESSAGE_BUFFER_LEN];memset(buffer, '\0', sizeof buffer);
    p->length = strlen(temp);
    snprintf(buffer,100,"%10ld\n",strlen(temp));
    strcat(buffer, temp);
    sendAll(buffer,socket);
}

void display_packet(packet* p){
    printf("----PACKET START----\n");
    printf("%ld\n",p->length);
    printf("%s\n",p->destination_ip);
    printf("%s\n",p->destination_port);
    printf("%ld\n",p->timestamp);
    printf("%ld\n",p->uid);
    printf("%s\n",p->message);
    printf("----PACKET END-----\n");
}

Code for helper.c (Some functions used by all the other codes):-

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netdb.h>
#include <poll.h>
#include <arpa/inet.h>
#include <sys/wait.h>
#include <signal.h>
#include <time.h>
#include <sys/stat.h>
#include <ctype.h>
#include <fcntl.h>
#include <pthread.h>

#define MESSAGE_BUFFER_LEN 20480

void read_char(char** into, char* from, int length){
    *into = malloc(length+1);memset(*into, '\0', sizeof *into);
    strcpy(*into, from);
    // printf("%s\n",into);
}

long read_long(char* from, int length){
    int i;long temp=0;
    for(i=0;i<length;i+=1){
        if(isdigit(*(from+i))){
            temp=temp*10;temp+=(long)(*(from+i) - '0');         
        }
    }
    return temp;
}

void write_long(long t1,char* m){
    int mkr=0;
    char temp[100];memset(temp,'\0',sizeof temp);
    while(t1!=0){
        temp[mkr] = ((int)(t1%10)) + '0';
        t1 = t1/10;
    }
    for(mkr=strlen(temp)-1;mkr>=0;mkr-=1){
        *(m+strlen(m))=temp[mkr];
    }
}

void *get_in_addr(struct sockaddr* sa){
    if(sa->sa_family == AF_INET){
        return &(((struct sockaddr_in *)sa)->sin_addr);
    }else{
        return &(((struct sockaddr_in6*)sa)->sin6_addr);
    }
}


void sigchld_handler(int s){
    int saved_errno=errno;
    while(waitpid(-1, NULL, WNOHANG) > 0);
    errno = saved_errno;
}

int sendAll(char* data_to_send,int socket_fd){
    int bytesleft = strlen(data_to_send);
    int total=0;int n;
    while(bytesleft>0){
        n = send(socket_fd,data_to_send + total, bytesleft, 0);
        if(n==-1)break;
        total+=n;
        bytesleft-=n;
    }
    return n==-1?-1:0;
}

long get_time_in_ns(){
    struct timespec start;clock_gettime(CLOCK_REALTIME,&start);
    long ct = ((long)start.tv_sec)*1e9 + ((long)start.tv_nsec);
    return ct;
}

The code isn't documented at all.

Upvotes: 0

Views: 143

Answers (1)

David Schwartz
David Schwartz

Reputation: 182753

For TCP, call shutdown on the socket. To be more complete:

  1. Set some flag that the thread will check so it will know that a shutdown is in process when it becomes unblocked.
  2. Do whatever you need to do to shut the connection down, ultimately calling shutdown on the socket when you're done. If you need to call shutdown as part of your teardown process, do it. If not, when you're done with your teardown process (if any) shutdown the connection in both directions.
  3. Do not, under any circumstances, call close on the socket until you can 100% confirm that no thread is, or might be, trying to access the socket. This is extremely important.

For UDP, send a datagram to the socket. That will unblock the thread as it receives the dummy datagram.

Upvotes: 1

Related Questions