/*
 * Pixelworks Flash Upgrade Gadget
 *
 * Copyright (C) 2016 Pixelworks, Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <usb/gadget.h>
#include <usb/composite.h>
#include <usb/ch9.h>
#include <common.h>
#include <malloc.h>
#include <progress.h>
#include <dma.h>
#include <errno.h>
#include <fcntl.h>
#include <libbb.h>
#include <init.h>
#include <fs.h>
#include <linux/mtd/mtd.h>

/*
 * Debug tracing.  Uncomment DEBUG to get every dprintf() prefixed with the
 * calling function and line; otherwise the calls compile to nothing and
 * their arguments are not evaluated.
 */
//#define DEBUG
#ifdef DEBUG
#define dprintf(fmt, ...) \
	printf("%s:%d:" fmt, __func__, __LINE__, ##__VA_ARGS__)
#else
#define dprintf(...) do { } while (0)
#endif

/* Name of the flash character device the image is programmed into. */
#define FLASHDEV "m25p0"
/* Size in bytes of the buffer attached to each bulk request. */
#define XFERLEN 65536

/* Per-function state of the Pixelworks flash upgrade gadget. */
struct f_pwfu {
	/* embedded generic function; recover the container with func_to_pwfu() */
	struct usb_function func;

	/* bulk endpoints claimed in pwfu_bind() */
	struct usb_ep *in;
	struct usb_ep *out;
	/*
	 * One XFERLEN-byte request per endpoint.  outreq carries command and
	 * image packets from the host.  NOTE(review): inreq is allocated but
	 * never queued in this file.
	 */
	struct usb_request *inreq;
	struct usb_request *outreq;

	/* fd of the /dev/pwfu.img RAM staging partition, or -1 when closed */
	int tmpfd;
	/* flash offset taken from the START command packet */
	u32 addr;
	/* number of image bytes received since the last START */
	u32 size;
};

/* Map the embedded usb_function back to its enclosing f_pwfu. */
static inline struct f_pwfu *func_to_pwfu(struct usb_function *f)
{
	return container_of(f, struct f_pwfu, func);
}


/*
 * Set to 1 by the OUT completion handler on the DONE command;
 * usb_pwfu_register() stops polling the gadget once it is set.
 */
static int pwfu_exit;

/* Indices into pwfu_string_tab[]. */
#define STRING_MANF 0
#define STRING_PROD 1
#define STRING_CONF 2
#define STRING_INTF 3

/*
 * Gadget strings.  The .id fields start at zero and receive their USB
 * string IDs from usb_string_id() in pwfu_driver_bind().
 */
static struct usb_string pwfu_string_tab[] = {
	[STRING_MANF].s = "Pixelworks",
	[STRING_PROD].s = "TopazEH",
	[STRING_CONF].s = "USB Flash Upgrade",
	[STRING_INTF].s = "USB Flash Upgrade",
	{  } /* end of list */
};

/* The string table above, offered for language ID 0x0409 (US English). */
static struct usb_gadget_strings pwfu_strings_en_us = {
	.language = 0x0409,
	.strings = pwfu_string_tab,
};

/* NULL-terminated list of string tables, referenced by pwfu_driver.strings. */
static struct usb_gadget_strings *pwfu_strings[] = {
	&pwfu_strings_en_us,
	NULL,
};

/*
 * One vendor-specific interface with a bulk IN and a bulk OUT endpoint.
 * bInterfaceNumber is set in pwfu_bind(); iInterface in pwfu_driver_bind().
 */
static struct usb_interface_descriptor pwfu_interface_desc = {
	.bLength =		USB_DT_INTERFACE_SIZE,
	.bDescriptorType =	USB_DT_INTERFACE,
	/* .bInterfaceNumber = DYNAMIC */
	.bNumEndpoints =	2,
	.bInterfaceClass =	USB_CLASS_VENDOR_SPEC,
	.bInterfaceSubClass =	0x00,
	.bInterfaceProtocol =	0x00,
};

/*
 * Full-speed bulk endpoints, 64-byte packets.  Only the direction bit is
 * given here; these descriptors are passed to usb_ep_autoconfig() in
 * pwfu_bind(), which selects the endpoint.
 */
static struct usb_endpoint_descriptor pwfu_bulk_in_fs_desc = {
	.bLength =		USB_DT_ENDPOINT_SIZE,
	.bDescriptorType =	USB_DT_ENDPOINT,
	.bEndpointAddress =	USB_DIR_IN,
	.wMaxPacketSize =	cpu_to_le16(64),
	.bmAttributes =		USB_ENDPOINT_XFER_BULK,
};

static struct usb_endpoint_descriptor pwfu_bulk_out_fs_desc = {
	.bLength =		USB_DT_ENDPOINT_SIZE,
	.bDescriptorType =	USB_DT_ENDPOINT,
	.bEndpointAddress =	USB_DIR_OUT,
	.wMaxPacketSize =	cpu_to_le16(64),
	.bmAttributes =		USB_ENDPOINT_XFER_BULK,
};

/* NULL-terminated full-speed descriptor set: interface + both endpoints. */
static struct usb_descriptor_header *pwfu_fs_function[] = {
	(struct usb_descriptor_header *) &pwfu_interface_desc,
	(struct usb_descriptor_header *) &pwfu_bulk_in_fs_desc,
	(struct usb_descriptor_header *) &pwfu_bulk_out_fs_desc,
	NULL,
};

/*
 * High-speed bulk endpoints, 512-byte packets.  pwfu_bind() copies the
 * endpoint addresses chosen for the full-speed descriptors into these.
 */
static struct usb_endpoint_descriptor pwfu_bulk_in_hs_desc = {
	.bLength =		USB_DT_ENDPOINT_SIZE,
	.bDescriptorType =	USB_DT_ENDPOINT,
	.bEndpointAddress =	USB_DIR_IN,
	.wMaxPacketSize =	cpu_to_le16(512),
	.bmAttributes =		USB_ENDPOINT_XFER_BULK,
};

static struct usb_endpoint_descriptor pwfu_bulk_out_hs_desc = {
	.bLength =		USB_DT_ENDPOINT_SIZE,
	.bDescriptorType =	USB_DT_ENDPOINT,
	.bEndpointAddress =	USB_DIR_OUT,
	.wMaxPacketSize =	cpu_to_le16(512),
	.bmAttributes =		USB_ENDPOINT_XFER_BULK,
};

/* NULL-terminated high-speed descriptor set: interface + both endpoints. */
static struct usb_descriptor_header *pwfu_hs_function[] = {
	(struct usb_descriptor_header *) &pwfu_interface_desc,
	(struct usb_descriptor_header *) &pwfu_bulk_in_hs_desc,
	(struct usb_descriptor_header *) &pwfu_bulk_out_hs_desc,
	NULL,
};

/*
 * Completion callback of the bulk IN request.  It only reports the
 * outcome through dprintf(); nothing is re-queued.
 */
static void pwfu_in_complete(struct usb_ep *ep, struct usb_request *req)
{
	int status = req->status;

	dprintf("\n");

	if (status == 0)
		dprintf("IN req completed\n");
	else if (status == -ESHUTDOWN)
		dprintf("disconnect\n");
	else
		dprintf("unexpected status %d\n", status);
}

/*
 * Report whether the first 100 bytes of @buf all equal @code.  Command
 * packets carry their command byte repeated across these 100 bytes.
 * @buf must hold at least 100 bytes.
 */
static bool pwfu_cmd_check(u8 *buf, u8 code)
{
	const u8 *p = buf;
	const u8 *end = buf + 100;

	while (p < end) {
		if (*p++ != code)
			return false;
	}

	return true;
}

static void pwfu_out_complete(struct usb_ep *ep, struct usb_request *req)
{
	struct f_pwfu *pwfu = ep->driver_data;
	enum { ENDOFDATA, START, DATA, CONFIRM, DONE } cmd = DATA;
	u8 *buf = (u8 *)req->buf;
	int rc, len, total;
	u32 size;
	struct cdev *cdev;
	int flashfd = -1;

	dprintf("req->status = %d, %u bytes\n", req->status, req->actual);

	if (req->status == -ESHUTDOWN) {
		dprintf("disconnect or reset\n");
		return;
	}
	else if (req->status) {
		dprintf("unexpected status %d\n", req->status);
		return;
	}

	if (req->actual == 105) { // command packet
		if (pwfu_cmd_check(buf, 0xa5))
			cmd = START;
		else if (pwfu_cmd_check(buf, 0xa6))
			cmd = ENDOFDATA;
		else if (pwfu_cmd_check(buf, 0xa7))
			cmd = CONFIRM;
		else if (pwfu_cmd_check(buf, 0xa8))
			cmd = DONE;
		else
			cmd = DATA;
	}
	switch (cmd) {
	case START:
		memcpy(&pwfu->addr, &buf[100], 4);
		printf("START 0x%08x\n", pwfu->addr);
		pwfu->size = 0;
		// create temporary memory partition for download, large enough flash image
		cdev = devfs_add_partition("mem", 0x8000000, 0x04000000,
					   DEVFS_PARTITION_PWFU, "pwfu.img");
		rc = IS_ERR(cdev) ? PTR_ERR(cdev) : 0;
		if (rc && rc != -EEXIST) {
			printf("can't add memory partition: %d\n", rc);
			break;
		}
		if (pwfu->tmpfd >= 0)
			close(pwfu->tmpfd);
		pwfu->tmpfd = open("/dev/pwfu.img", O_RDWR);
		if (pwfu->tmpfd < 0)
			printf("ERROR: can create tmpfile\n");
		break;
	case DATA:
		pwfu->size += req->actual;
		printf("DATA %u bytes, %u total\n", req->actual, pwfu->size);
		if (req->actual) {
			rc = write(pwfu->tmpfd, buf, req->actual);
			if (rc != (int)req->actual)
				printf("ERROR: can write tmpfile: %d\n", rc);
		}
		break;
	case ENDOFDATA:
		printf("XFER END, %u bytes\n", pwfu->size);
		if (!pwfu->size || pwfu->tmpfd < 0) {
			printf("XFER END: error not tmpfile\n");
			goto err;
		}
		// create temporary partition for flash
		// first get flash eraseblock size for rounding up the size
		cdev = cdev_by_name(FLASHDEV);
		if (!cdev || !cdev->mtd) {
			printf("can't find flash device '%s'\n", FLASHDEV);
			goto err;
		}
		size = ALIGN(pwfu->size, cdev->mtd->erasesize);
		cdev = devfs_add_partition(FLASHDEV, pwfu->addr, size,
					   DEVFS_PARTITION_PWFU, "pwfu");
		rc = IS_ERR(cdev) ? PTR_ERR(cdev) : 0;
		if (rc) {
			printf("can't add flash partition: %d\n", rc);
			goto err;
		}
		flashfd = open("/dev/pwfu", O_WRONLY);
		if (flashfd < 0) {
			printf("ERROR: can't open flash: %d\n", flashfd);
			goto err;
		}
		printf("erasing flash\n");
		rc = erase(flashfd, ~0, 0);
		if (rc) {
			printf("ERROR: can't erase flash: %d\n", rc);
			goto err;
		}
		printf("writing flash\n");
		buf = xmalloc(4096);
		init_progression_bar(pwfu->size);
		if (lseek(pwfu->tmpfd, 0, SEEK_SET) < 0) {
			printf("ERROR: can't rewind tmp file\n");
			goto err;
		}
		for (total = 0; total < pwfu->size; ) {
			len = min(4096u, pwfu->size);
			if (!len)
				break;
			rc = read(pwfu->tmpfd, buf, len);
			if (rc < 0)
				break;
			total += rc;
			rc = write(flashfd, buf, rc);
			if (rc < 0)
				break;
			show_progress(total);
		}
		printf("\n");
		free(buf);
		if (rc < 0)
			printf("ERROR: flash failed: %d\n", rc);
		else
			printf("flash write complete\n");
err:
		close(flashfd);
		close(pwfu->tmpfd);
		pwfu->tmpfd = -1;
		rc = devfs_del_partition("pwfu.img");
		if (rc)
			printf("can't delete memory partition: %d\n", rc);
		rc = devfs_del_partition("pwfu");
		if (rc)
			printf("can't delete flash partition: %d\n", rc);
		break;
	case CONFIRM:
		dprintf("CONFIRM\n");
		break;
	case DONE:
		dprintf("DONE\n");
		pwfu_exit = 1;
		break;
	}

	rc = usb_ep_queue(ep, req);
	if (rc < 0)
		dprintf("usb_ep_queue failed: %d\n", rc);
}

/*
 * Function bind: claim an interface number and a bulk IN/OUT endpoint
 * pair, publish the full- and high-speed descriptor sets, and allocate
 * one XFERLEN-byte request per endpoint.
 *
 * Returns 0 on success or a negative error code.  If descriptor
 * assignment or request allocation fails, whatever was set up by those
 * steps is released again before returning.
 */
static int pwfu_bind(struct usb_configuration *c, struct usb_function *f)
{
	struct f_pwfu *pwfu = func_to_pwfu(f);
	struct usb_composite_dev *cdev = c->cdev;
	struct usb_ep *ep;
	int id, rc;

	dprintf("\n");
	id = usb_interface_id(c, f);
	if (id < 0)
		return id;
	pwfu_interface_desc.bInterfaceNumber = id;

	ep = usb_ep_autoconfig(cdev->gadget, &pwfu_bulk_in_fs_desc);
	if (!ep)
		return -ENODEV;
	pwfu->in = ep;
	ep->driver_data = pwfu;

	ep = usb_ep_autoconfig(cdev->gadget, &pwfu_bulk_out_fs_desc);
	if (!ep)
		return -ENODEV;
	pwfu->out = ep;
	ep->driver_data = pwfu;

	/* high speed uses the endpoints autoconfig picked for full speed */
	pwfu_bulk_in_hs_desc.bEndpointAddress = pwfu_bulk_in_fs_desc.bEndpointAddress;
	pwfu_bulk_out_hs_desc.bEndpointAddress = pwfu_bulk_out_fs_desc.bEndpointAddress;

	rc = usb_assign_descriptors(f, pwfu_fs_function, pwfu_hs_function, NULL);
	if (rc)
		return rc;

	pwfu->inreq = usb_ep_alloc_request(pwfu->in);
	if (!pwfu->inreq) {
		rc = -ENOMEM;
		goto err_free_desc;
	}
	pwfu->inreq->buf = dma_alloc(XFERLEN);
	pwfu->inreq->length = XFERLEN;
	pwfu->inreq->complete = pwfu_in_complete;

	pwfu->outreq = usb_ep_alloc_request(pwfu->out);
	if (!pwfu->outreq) {
		rc = -ENOMEM;
		goto err_free_inreq;
	}
	pwfu->outreq->buf = dma_alloc(XFERLEN);
	pwfu->outreq->length = XFERLEN;
	pwfu->outreq->complete = pwfu_out_complete;

	return 0;

err_free_inreq:
	dma_free(pwfu->inreq->buf);
	usb_ep_free_request(pwfu->in, pwfu->inreq);
	pwfu->inreq = NULL;
err_free_desc:
	usb_free_all_descriptors(f);
	return rc;
}

/*
 * Function unbind: close a still-open staging file, remove the RAM staging
 * partition if an unfinished transfer left it behind, then cancel and free
 * both requests and the descriptor sets assigned in pwfu_bind().
 */
static void pwfu_unbind(struct usb_configuration *c, struct usb_function *f)
{
	struct f_pwfu *pwfu = func_to_pwfu(f);

	dprintf("\n");
	if (pwfu->tmpfd >= 0) {
		close(pwfu->tmpfd);
		pwfu->tmpfd = -1;
	}
	/* best effort: the partition normally no longer exists */
	devfs_del_partition("pwfu.img");

	usb_ep_dequeue(pwfu->out, pwfu->outreq);
	usb_ep_dequeue(pwfu->in, pwfu->inreq);
	/* buffers came from dma_alloc(), so release them with dma_free() */
	dma_free(pwfu->outreq->buf);
	dma_free(pwfu->inreq->buf);
	usb_ep_free_request(pwfu->out, pwfu->outreq);
	usb_ep_free_request(pwfu->in, pwfu->inreq);
	pwfu->outreq = NULL;
	pwfu->inreq = NULL;

	usb_free_all_descriptors(f);
}

static void pwfu_disable(struct usb_function *f)
{
	struct f_pwfu *pwfu = func_to_pwfu(f);

	dprintf("\n");
	usb_ep_disable(pwfu->out);
	usb_ep_disable(pwfu->in);
}

/*
 * Activate the interface.  This configures and enables the IN endpoint,
 * then the OUT endpoint, and queues the OUT request so the host can start
 * sending.  Any failure disables both endpoints and returns the negative
 * error code.
 */
static int pwfu_set_alt(struct usb_function *f, unsigned intf, unsigned alt)
{
	struct f_pwfu *pwfu = func_to_pwfu(f);
	struct usb_gadget *gadget = f->config->cdev->gadget;
	int ret;

	dprintf("\n");

	ret = config_ep_by_speed(gadget, f, pwfu->in);
	if (ret < 0) {
		dprintf("usb_ep_config IN failed: %d\n", ret);
		goto fail;
	}
	ret = usb_ep_enable(pwfu->in);
	if (ret < 0) {
		dprintf("usb_ep_enable IN failed: %d\n", ret);
		goto fail;
	}

	ret = config_ep_by_speed(gadget, f, pwfu->out);
	if (ret < 0) {
		dprintf("usb_ep_config OUT failed: %d\n", ret);
		goto fail;
	}
	ret = usb_ep_enable(pwfu->out);
	if (ret < 0) {
		dprintf("usb_ep_enable OUT failed: %d\n", ret);
		goto fail;
	}

	/* arm the OUT endpoint for the first host packet */
	ret = usb_ep_queue(pwfu->out, pwfu->outreq);
	if (ret < 0) {
		dprintf("usb_ep_queue failed: %d\n", ret);
		goto fail;
	}

	return 0;

fail:
	pwfu_disable(f);
	return ret;
}

/* Control requests addressed to this function are all rejected. */
static int pwfu_setup(struct usb_function *f, const struct usb_ctrlrequest *ctrl)
{
	dprintf("\n");

	return -EOPNOTSUPP;
}

/* Configuration-level unbind hook; there is nothing to release here. */
static void pwfu_unbind_config(struct usb_configuration *c)
{
	dprintf("\n");
}

/* The gadget's single, self-powered configuration (value 1). */
static struct usb_configuration pwfu_config_driver = {
	.label			= "USB PWFU",
	.unbind			= pwfu_unbind_config,
	.bConfigurationValue	= 1,
	.bmAttributes		= USB_CONFIG_ATT_SELFPOWER,
};

/*
 * Device descriptor.  The 16-bit fields are little-endian on the wire, so
 * they are wrapped in cpu_to_le16() like wMaxPacketSize in the endpoint
 * descriptors.  iManufacturer and iProduct are set in pwfu_driver_bind().
 */
static struct usb_device_descriptor pwfu_dev_desc = {
	.bLength		= USB_DT_DEVICE_SIZE,
	.bDescriptorType	= USB_DT_DEVICE,
	.bcdUSB			= cpu_to_le16(0x0200),
	.bDeviceClass		= 0x00,
	.bDeviceSubClass	= 0x00,
	.bDeviceProtocol	= 0x00,
	.idVendor		= cpu_to_le16(0x0471),
	.idProduct		= cpu_to_le16(0x0010),
	.bcdDevice		= cpu_to_le16(0x0002),
	.bNumConfigurations	= 0x01,
};

/* Function instance and function obtained in pwfu_driver_bind(). */
static struct usb_function_instance *pwfu_fi;
static struct usb_function *pwfu_f;

static int pwfu_driver_bind(struct usb_composite_dev *cdev)
{
	int id, rc;

	dprintf("\n");
	id = usb_string_id(cdev);
	if (id < 0)
		return id;
	pwfu_string_tab[STRING_MANF].id = id;
	pwfu_dev_desc.iManufacturer = id;

	id = usb_string_id(cdev);
	if (id < 0)
		return id;
	pwfu_string_tab[STRING_PROD].id = id;
	pwfu_dev_desc.iProduct = id;

	id = usb_string_id(cdev);
	if (id < 0)
		return id;
	pwfu_string_tab[STRING_CONF].id = id;
	pwfu_config_driver.iConfiguration = id;

	id = usb_string_id(cdev);
	if (id < 0)
		return id;
	pwfu_string_tab[STRING_INTF].id = id;
	pwfu_interface_desc.iInterface = id;

	rc = usb_add_config_only(cdev, &pwfu_config_driver);
	if (rc < 0)
		return rc;
	pwfu_fi = usb_get_function_instance("pwfu");
	if (IS_ERR(pwfu_fi))
		return PTR_ERR(pwfu_fi);
	pwfu_f = usb_get_function(pwfu_fi);
	if (IS_ERR(pwfu_f))
		return PTR_ERR(pwfu_f);
	rc = usb_add_function(&pwfu_config_driver, pwfu_f);
	return rc;
}

static int pwfu_driver_unbind(struct usb_composite_dev *cdev)
{
	usb_put_function(pwfu_f);
	usb_put_function_instance(pwfu_fi);

	return 0;
}

/*
 * Composite driver for the flash upgrade gadget (high speed at most).
 * usb_pwfu_register() probes it and unregisters it again when done.
 */
static struct usb_composite_driver pwfu_driver = {
	.name = "pwfu",
	.dev = &pwfu_dev_desc,
	.strings = pwfu_strings,
	.max_speed = USB_SPEED_HIGH,
	.bind = pwfu_driver_bind,
	.unbind = pwfu_driver_unbind,
};

/*
 * Run the flash upgrade gadget in the foreground.
 *
 * Probes the composite driver, then polls the gadget until the host sends
 * the DONE command or the user presses Ctrl-C, and unregisters the driver.
 * Returns the probe error if probing fails, otherwise 0.
 */
int usb_pwfu_register(void)
{
	int ret;

	dprintf("\n");
	pwfu_exit = 0;

	ret = usb_composite_probe(&pwfu_driver);
	if (ret)
		return ret;

	do {
		usb_gadget_poll();
	} while (!ctrlc() && !pwfu_exit);

	usb_composite_unregister(&pwfu_driver);

	return 0;
}

/* Free the f_pwfu that pwfu_alloc_func() allocated around @f. */
static void pwfu_free_func(struct usb_function *f)
{
	free(func_to_pwfu(f));
}

static struct usb_function *pwfu_alloc_func(struct usb_function_instance *fi)
{
	struct f_pwfu *pwfu;

	pwfu = xzalloc(sizeof(*pwfu));
	pwfu->func.name = "PWFU";
	// pwfu->func.strings = ;
	pwfu->func.bind = pwfu_bind;
	pwfu->func.unbind = pwfu_unbind;
	pwfu->func.set_alt = pwfu_set_alt;
	pwfu->func.setup = pwfu_setup;
	pwfu->func.disable = pwfu_disable;
	pwfu->func.free_func = pwfu_free_func;
	pwfu->tmpfd = -1;

	return &pwfu->func;
}

/* Release an instance created by pwfu_alloc_instance(). */
static void pwfu_free_instance(struct usb_function_instance *fi)
{
	kfree(fi);
}

/* Allocate a zeroed function instance that knows how to free itself. */
static struct usb_function_instance *pwfu_alloc_instance(void)
{
	struct usb_function_instance *inst = xzalloc(sizeof(*inst));

	inst->free_func_inst = pwfu_free_instance;

	return inst;
}

/*
 * Register the function driver under the name "pwfu", which
 * pwfu_driver_bind() looks up via usb_get_function_instance("pwfu").
 */
DECLARE_USB_FUNCTION_INIT(pwfu, pwfu_alloc_instance, pwfu_alloc_func);
