diff --git a/src/usb/hid/hid.c b/src/usb/hid/hid.c index 0049499..d9887cb 100644 --- a/src/usb/hid/hid.c +++ b/src/usb/hid/hid.c @@ -124,16 +124,20 @@ void tud_hid_set_report_cb(uint8_t itf, uint8_t report_id, hid_report_type_t rep usb_rx(buffer, bufsize); } -void u2f_error(uint8_t error) { +int u2f_error(uint8_t error) { u2f_resp = (U2FHID_FRAME *)usb_get_tx(); u2f_resp->cid = u2f_req->cid; u2f_resp->init.cmd = U2FHID_ERROR; u2f_resp->init.bcntl = 1; - u2f_resp->init.data[0] = ERR_INVALID_CMD; + u2f_resp->init.data[0] = error; hid_write(64); + usb_clear_rx(); + return 0; } uint8_t last_cmd = 0; +uint8_t last_seq = 0; +U2FHID_FRAME last_req; int driver_process_usb_packet(uint16_t read) { int apdu_sent = 0; @@ -141,8 +145,14 @@ int driver_process_usb_packet(uint16_t read) { { DEBUG_PAYLOAD(usb_get_rx(),64); memset(u2f_resp, 0, sizeof(U2FHID_FRAME)); + if (u2f_req->cid == 0x0 || (u2f_req->cid == CID_BROADCAST && u2f_req->init.cmd != U2FHID_INIT)) + return u2f_error(ERR_SYNC_FAIL); if (FRAME_TYPE(u2f_req) == TYPE_INIT) { + if (MSG_LEN(u2f_req) > U2F_MAX_PACKET_SIZE) + return u2f_error(ERR_INVALID_LEN); + if (msg_packet.len > 0 && last_req.cid != u2f_req->cid) //We are in a transaction + return u2f_error(ERR_CHANNEL_BUSY); printf("command %x\n", FRAME_CMD(u2f_req)); printf("len %d\n", MSG_LEN(u2f_req)); msg_packet.len = msg_packet.current_len = 0; @@ -152,12 +162,24 @@ int driver_process_usb_packet(uint16_t read) { memcpy(msg_packet.data + msg_packet.current_len, u2f_req->init.data, 64-7); msg_packet.current_len += 64 - 7; } + memcpy(&last_req, u2f_req, sizeof(U2FHID_FRAME)); last_cmd = u2f_req->init.cmd; + last_seq = 0; } else { - memcpy(msg_packet.data + msg_packet.current_len, u2f_req->cont.data, MIN(64 - 5, msg_packet.len - msg_packet.current_len)); - msg_packet.current_len += MIN(64 - 5, msg_packet.len - msg_packet.current_len); + if (msg_packet.len == 0) //Received a cont with a prior init pkt + return 0; + if (last_seq != u2f_req->cont.seq) + return u2f_error(ERR_INVALID_SEQ); + if (last_req.cid == u2f_req->cid) { + memcpy(msg_packet.data + msg_packet.current_len, u2f_req->cont.data, MIN(64 - 5, msg_packet.len - msg_packet.current_len)); + msg_packet.current_len += MIN(64 - 5, msg_packet.len - msg_packet.current_len); + memcpy(&last_req, u2f_req, sizeof(U2FHID_FRAME)); + } + //else // Received a cont from another channel. Silently discard + last_seq++; } + if (u2f_req->init.cmd == U2FHID_INIT) { u2f_resp = (U2FHID_FRAME *)usb_get_tx(); U2FHID_INIT_REQ *req = (U2FHID_INIT_REQ *)u2f_req->init.data; @@ -180,7 +202,7 @@ int driver_process_usb_packet(uint16_t read) { } else if (u2f_req->init.cmd == U2FHID_WINK) { if (MSG_LEN(u2f_req) != 0) { - u2f_error(ERR_INVALID_LEN); + return u2f_error(ERR_INVALID_LEN); } u2f_resp = (U2FHID_FRAME *)usb_get_tx(); memcpy(u2f_resp, u2f_req, sizeof(U2FHID_FRAME)); @@ -196,7 +218,7 @@ int driver_process_usb_packet(uint16_t read) { } else { if (msg_packet.len == 0) - u2f_error(ERR_INVALID_CMD); + return u2f_error(ERR_INVALID_CMD); } // echo back anything we received from host //tud_hid_report(0, buffer, bufsize);