diff --git a/examples/dualheaded_datautils/Example_.ipynb b/examples/dualheaded_datautils/Example_.ipynb new file mode 100644 index 0000000..a8a0b0a --- /dev/null +++ b/examples/dualheaded_datautils/Example_.ipynb @@ -0,0 +1,263 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing split functions and dataloading\n", + "\n", + "In this section, we test split functions (utils), custom datasets classes and dataloading (with standard pytorch dataloader). " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import print_function\n", + "import syft as sy\n", + "import torch\n", + "from torch.utils.data import Dataset\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.data._utils.collate import default_collate\n", + "from typing import List, Tuple\n", + "from uuid import UUID\n", + "from uuid import uuid4\n", + "from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler\n", + "\n", + "from abc import ABC, abstractmethod\n", + "from torchvision import datasets, transforms\n", + "\n", + "import utils\n", + "import dataloaders\n", + "\n", + "hook = sy.TorchHook(torch)\n", + "\n", + "transform = transforms.Compose([transforms.ToTensor(),\n", + " transforms.Normalize((0.5,), (0.5,)),\n", + " ])\n", + "trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)\n", + "#trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)\n", + "\n", + "# create some workers\n", + "client_1 = sy.VirtualWorker(hook, id=\"client_1\")\n", + "client_2 = sy.VirtualWorker(hook, id=\"client_2\")\n", + "\n", + "server = sy.VirtualWorker(hook, id= \"server\") \n", + "\n", + "data_owners = [client_1, client_2]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "#get a verticalFederatedDatase\n", + "vfd = utils.split_data_create_vertical_dataset(trainset, data_owners)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "loader = DataLoader(vfd, batch_size=4, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'client_1': [tensor([[[[-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " ...,\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.]]],\n", + "\n", + "\n", + " [[[-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " ...,\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.]]],\n", + "\n", + "\n", + " [[[-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " ...,\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.]]],\n", + "\n", + "\n", + " [[[-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " ...,\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., -1.],\n", + " [-1., -1., -1., ..., -1., -1., 
-1.]]]]), tensor([9, 8, 2, 1]), tensor([51574., 39668., 24844., 32204.], dtype=torch.float64)], 'client_2': [tensor([[[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " ...,\n", + " [ 0.1686, 0.9922, 0.5137, ..., -1.0000, -1.0000, -1.0000],\n", + " [ 0.1686, 0.9922, 0.3412, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n", + "\n", + "\n", + " [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " ...,\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n", + "\n", + "\n", + " [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " ...,\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]],\n", + "\n", + "\n", + " [[[-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " ...,\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000],\n", + " [-1.0000, -1.0000, -1.0000, ..., -1.0000, -1.0000, -1.0000]]]]), tensor([9, 8, 2, 1]), tensor([51574., 39668., 24844., 32204.], dtype=torch.float64)]}\n" + ] + } + ], + "source": [ + "for el in loader: \n", + " print(el)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "#as in https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb\n", + "from torch import nn, optim\n", + "\n", + "model_locations = [client_1, client_2, server]\n", + "\n", + "input_size= [28*14, 28*14]\n", + "hidden_sizes= {\"client_1\": [32, 64], \"client_2\":[32, 64], \"server\":[128, 64]}\n", + "\n", + "#create model segment for each worker\n", + "models = {\n", + " \"client_1\": nn.Sequential(\n", + " nn.Linear(input_size[0], hidden_sizes[\"client_1\"][0]),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_sizes[\"client_1\"][0], hidden_sizes[\"client_1\"][1]),\n", + " nn.ReLU(),\n", + " ),\n", + " \"client_2\": nn.Sequential(\n", + " nn.Linear(input_size[1], hidden_sizes[\"client_2\"][0]),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_sizes[\"client_2\"][0], hidden_sizes[\"client_2\"][1]),\n", + " nn.ReLU(),\n", + " ),\n", + " \"server\": nn.Sequential(\n", + " nn.Linear(hidden_sizes[\"server\"][0], hidden_sizes[\"server\"][1]),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_sizes[\"server\"][1], 10),\n", + " nn.LogSoftmax(dim=1)\n", + " )\n", + "}\n", + "\n", + "\n", + "\n", + "# Create optimisers for each segment and link to their segment\n", + "optimizers = [\n", + " optim.SGD(models[location.id].parameters(), 
lr=0.05,)\n", + " for location in model_locations\n", + "]\n", + "\n", + "\n", + "#send model segement to each client and server\n", + "for location in model_locations:\n", + " models[location.id].send(location)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/dualheaded_datautils/dataloaders.py b/examples/dualheaded_datautils/dataloaders.py new file mode 100644 index 0000000..8e445d5 --- /dev/null +++ b/examples/dualheaded_datautils/dataloaders.py @@ -0,0 +1,70 @@ +from __future__ import print_function +import syft as sy +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate +from typing import List, Tuple +from uuid import UUID +from uuid import uuid4 +from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler + +import datasets + + +"""I think this is not needed anymore""" + + +class SinglePartitionDataLoader(DataLoader): + """DataLoader for a single vertically-partitioned dataset""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + #self.collate_fn = id_collate_fn + + + +class VerticalFederatedDataLoader(DataLoader): + """Dataloader which batches data from a complete + set of vertically-partitioned datasets + + + DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, + batch_sampler=None, num_workers=0, collate_fn=None, + pin_memory=False, drop_last=False, timeout=0, + worker_init_fn=None, *, prefetch_factor=2, + persistent_workers=False) + """ + + def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, + batch_sampler=None, num_workers=0, collate_fn=None, + pin_memory=False, drop_last=False, timeout=0, + worker_init_fn=None, *, prefetch_factor=2, + persistent_workers=False): + + self.dataset = dataset + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + + self.workers = dataset.workers + + self.batch_samplers = {} + for worker in self.workers: + data_range = range(len(self.dataset)) + if shuffle: + sampler = RandomSampler(data_range) + else: + sampler = SequentialSampler(data_range) + batch_sampler = BatchSampler(sampler, self.batch_size, drop_last) + self.batch_samplers[worker] = batch_sampler + + single_loaders = [] + for k in self.dataset.datasets.keys(): + single_loaders.append(SinglePartitionDataLoader(self.dataset.datasets[k], batch_sampler=self.batch_samplers[k])) + + self.single_loaders = single_loaders + + def __len__(self): + return sum(len(x) for x in self.dataset.datasets.values()) // len(self.workers) \ No newline at end of file diff --git a/examples/dualheaded_datautils/datasets.py b/examples/dualheaded_datautils/datasets.py new file mode 100644 
index 0000000..3f27416 --- /dev/null +++ b/examples/dualheaded_datautils/datasets.py @@ -0,0 +1,151 @@ +from __future__ import print_function +import syft as sy +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate +from typing import List, Tuple +from uuid import UUID +from uuid import uuid4 +from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler + + +class BaseSet(Dataset): + def __init__(self, ids, values, worker_id=None, is_labels=False): + self.values_dic = {} + for i, l in zip(ids, values): + self.values_dic[i] = l + self.is_labels = is_labels + + self.ids = torch.Tensor(ids) + self.values = torch.Tensor(values) if is_labels else torch.stack(values) + + self.worker_id = None + if worker_id: + self.send_to_worker(worker_id) + + def send_to_worker(self, worker): + self.worker_id = worker + self.values_pointer = self.values.send(worker) + self.index_pointer = self.ids.send(worker) + return self.values_pointer, self.index_pointer + + def __getitem__(self, index): + """ + Args: + idx: index of the example we want to get + Returns: a tuple with data, label, index of a single example. + """ + return tuple([self.values[index], self.ids[index]]) + + def __len__(self): + """ + Returns: amount of samples in the dataset + """ + return self.values.shape[0] + + +class SampleSetWithLabels(Dataset): + def __init__(self, labelset, sampleset, worker_id=None): + #TO-DO: drop non-intersecting, now just assuming they are overlapping + #TO-DO: make sure values are sorted + self.labelset = labelset + self.sampleset = sampleset + + self.labels = labelset.values + self.values = sampleset.values + self.ids = sampleset.ids + + self.values_dic = {} + for k in labelset.values_dic.keys(): + self.values_dic[k] = tuple([sampleset.values_dic[k], labelset.values_dic[k]]) + + self.worker_id = None + if worker_id != None: + self.send_to_worker(worker_id) + + def send_to_worker(self, worker): + self.worker_id = worker + self.label_point, self.label_ix_pointer = self.labelset.send_to_worker(worker) + self.value_point, self.values_ix_pointer = self.sampleset.send_to_worker(worker) + return self.label_point, self.label_ix_pointer, self.value_point, self.values_ix_pointer + + + def __getitem__(self, index): + """ + Args: + idx: index of the example we want to get + Returns: a tuple with data, label, index of a single example. + """ + return tuple([self.values[index], self.labels[index], self.ids[index]]) + + def __len__(self): + """ + Returns: amount of samples in the dataset + """ + return self.values.shape[0] + + + +class VerticalFederatedDataset(): + """ + VerticalFederatedDataset, which acts as a dictionary between BaseVerticalDatasets, + already sent to remote workers, and the corresponding workers. + This serves as an input to VerticalFederatedDataLoader. + Same principle as in Syft 2.0 for FederatedDataset: + https://github.com/OpenMined/PySyft/blob/syft_0.2.x/syft/frameworks/torch/fl/dataset.py + + Args: + datasets: list of BaseVerticalDatasets. 
+ """ + def __init__(self, datasets): + + self.datasets = {} #dictionary to keep track of BaseVerticalDatasets and corresponding workers + + indices_list = set() + + #take intersecting items + for dataset in datasets: + indices_list.update(dataset.ids) + self.datasets[dataset.worker_id] = dataset + + self.workers = self.__workers() + + #create a list of dictionaries + self.dict_items_list = [] + + for index in indices_list: + curr_dict = {} + for w in self.workers: + curr_dict[w] = tuple(list(self.datasets[w].values_dic[index.item()])+[index.item()]) + + self.dict_items_list.append(curr_dict) + + self.indices = list(indices_list) + + + def __workers(self): + """ + Returns: list of workers + """ + return list(self.datasets.keys()) + + def __getitem__(self, idx): + """ + Args: + worker_id[str,int]: ID of respective worker + Returns: + Get dataset item from different workers + """ + + return self.dict_items_list[idx] + + def __len__(self): + return len(self.indices) + + def __repr__(self): + + fmt_str = "FederatedDataset\n" + fmt_str += f" Distributed accross: {', '.join(str(x) for x in self.workers)}\n" + fmt_str += f" Number of datapoints: {self.__len__()}\n" + return fmt_str \ No newline at end of file diff --git a/examples/dualheaded_datautils/enhancedSplitWorkers.py b/examples/dualheaded_datautils/enhancedSplitWorkers.py new file mode 100644 index 0000000..5c5909b --- /dev/null +++ b/examples/dualheaded_datautils/enhancedSplitWorkers.py @@ -0,0 +1,54 @@ +from __future__ import print_function +import syft as sy +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate +from typing import List, Tuple +from uuid import UUID +from uuid import uuid4 +from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler + +import datasets + + +"""This is an experimental work-in-progress feature""" + + +class EnhanchedWorker(): + """Single worker with a role (label / data holder) and a model""" + + def __init__(self, worker, dataset, model, level=1): + + self.worker = worker + self.dataset = dataset #It can also be None, and then it would be only computational + self.model = model + + self.level = max(level, 0) #it should start from zero, otherwise throw error #TODO: implement error throwing + + + +class FederatedWorkerChain(): + + """Class wrapping all the workers with their corresponding model """ + def __init__(self, enhanchedWorkersList): + self.enhanchedWorkersList = enhanchedWorkersList + dic_workers = {} + for ew in enhanchedWorkersList: + if ew.level not in dic_workers.keys(): + dic_workers[ew.level] = [] + + dic_workers[ew.level].append(ew) + + self.dic_workers = dic_workers + + + #TODO: implement check that the level passed is valid + def get_same_level_en_workers(self, level): + return self.dic_workers[level] + + def get_previous_level_en_workers(self, level): + return self.dic_workers[level-1] + + def get_next_level_en_workers(self, level): + return self.dic_workers[level+1] diff --git a/examples/dualheaded_datautils/utils.py b/examples/dualheaded_datautils/utils.py new file mode 100644 index 0000000..afe2cfd --- /dev/null +++ b/examples/dualheaded_datautils/utils.py @@ -0,0 +1,120 @@ +from __future__ import print_function +import syft as sy +import torch +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate +from typing import List, Tuple +from uuid import UUID +from uuid import uuid4 +from 
torch.utils.data import SequentialSampler, RandomSampler, BatchSampler
+
+import dataloaders
+import datasets
+from datasets import *
+
+"""
+Utility functions to split and distribute the data across different workers,
+create vertical datasets and federate them. The dataset and dataloader classes live in the accompanying datasets and dataloaders modules.
+This code is meant to be used with dual-headed Neural Networks, where a number of different workers
+hold vertical partitions of the same samples, and a server holds only the labels.
+Code built upon:
+- Abbas Ismail's (@abbas5253) work on dual-headed NN. In particular, check Configuration 1:
+  https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/Configuration_1.ipynb
+- Syft 2.0 Federated Learning dataset and dataloader: https://github.com/OpenMined/PySyft/tree/syft_0.2.x/syft/frameworks/torch/fl
+TODO:
+    - replace ids with UUIDs
+    - create a class for splitting the data
+    - check that it works (or adapt it to work) on data other than images
+    - dictionary keys should be worker ids, not workers themselves
+
+    - the custom dataloader class is probably not needed anymore (Discuss)
+"""
+
+
+
+
+def split_data(dataset, worker_list=None, n_workers=2):
+    """
+    Utility function to create a vertical split of the data. It also creates a numerical index to keep
+    track of each sample across the different splits.
+    Args:
+        dataset: an iterable object representing the dataset. Each element of the iterable
+        is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch.
+        #TODO: add support for taking a DataLoader as input, to iterate over batches instead of single examples.
+        worker_list (optional): the list of VirtualWorkers to distribute the data vertically across.
+        n_workers (optional, default=2): the number of splits to create; only used when worker_list is not passed.
+        label_server (optional, not yet a parameter of this function): the server which owns only the labels (e.g. in a dual-headed NN setting)
+        #TODO: add the code to send labels to the server
+    Returns:
+        a tuple (dic_single_datasets, label_list, index_list): a dictionary holding as keys the ids of the workers
+        passed as parameters (or integers identifying each split) and as values the list of data tensors held by that worker,
+        plus the list of labels and the list of indices used to keep track of the same data point across the splits.
+    """
+
+    if worker_list is None:
+        worker_list = list(range(0, n_workers))  #note: the code below expects workers exposing an .id attribute
+
+    #counter to create the index of different data samples
+    idx = 0
+
+    #dictionary to accommodate the split data
+    dic_single_datasets = {}
+    for worker in worker_list:
+        """
+        Each value is the list of this worker's data tensors;
+        labels and indices are collected separately below.
+        """
+        dic_single_datasets[worker.id] = []
+
+    """
+    Loop through the dataset to split the data and labels vertically across workers.
+    Splitting method from @abbas5253: https://github.com/abbas5253/SplitNN-for-Vertically-Partitioned-Data/blob/master/distribute_data.py
+    """
+    label_list = []
+    index_list = []
+    for tensor, label in dataset:
+        height = tensor.shape[-1]//len(worker_list)  #width of each slice: the split is along the last image dimension
+        i = 0
+        uuid_idx = uuid4()  #currently unused, see the TODO about replacing integer ids with UUIDs
+        for worker in worker_list[:-1]:
+            dic_single_datasets[worker.id].append(tensor[:, :, height * i : height * (i + 1)])
+            i += 1
+
+        #add the value of the last worker / split
+        dic_single_datasets[worker_list[-1].id].append(tensor[:, :, height * (i) : ])
+        label_list.append(label)
+        index_list.append(idx)
+
+        idx += 1
+
+    return dic_single_datasets, label_list, index_list
+
+
+def split_data_create_vertical_dataset(dataset, worker_list, label_server=None):
+    """
+    Utility function to distribute the data vertically across workers and create a vertical federated dataset.
+    Args:
+        dataset: an iterable object representing the dataset. Each element of the iterable
+        is supposed to be a tuple of [tensor, label]. It could be an iterable "Dataset" object in PyTorch.
+        #TODO: add support for taking a DataLoader as input, to iterate over batches instead of single examples.
+        worker_list: the list of VirtualWorkers to distribute the data vertically across.
+        label_server (optional, currently unused): the server which owns only the labels (e.g. in a dual-headed NN setting)
+    Returns:
+        a VerticalFederatedDataset.
+    """
+
+    #get a dictionary of workers --> data, plus label_list and index_list, ordered
+    dic_single_datasets, label_list, index_list = split_data(dataset, worker_list=worker_list)
+
+    #instantiate BaseSets
+    label_set = BaseSet(index_list, label_list, is_labels=True)
+    base_datasets_list = []
+    for w in dic_single_datasets.keys():
+        bs = BaseSet(index_list, dic_single_datasets[w], is_labels=False)
+        base_datasets_list.append(SampleSetWithLabels(label_set, bs, worker_id=w))
+
+    #create VerticalFederatedDataset
+    return VerticalFederatedDataset(base_datasets_list)
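A possible next step, not implemented in the files above, is to feed the batches produced by the DataLoader through the model segments defined in the notebook. The sketch below is a minimal, forward-only illustration under several assumptions: the notebook objects (models, loader, data_owners, server) are in scope and the model segments have already been sent to their locations; each batch is a dictionary keyed by worker id holding [image half, labels, ids]; the 28x14 image halves are flattened to 28*14 features before entering the client segments; and, since the loader yields local tensors in this example, each half is sent to its owning worker on the fly. Moving the intermediate activations to the server and concatenating them follows the referenced Configuration_1 notebook; a full training step would additionally have to propagate gradients back through each client segment, which is not shown here.

# Forward-only sketch (assumes the notebook's `models`, `loader`, `data_owners` and `server`).
import torch

workers_by_id = {w.id: w for w in data_owners}

for batch in loader:
    remote_activations = []
    for worker_id, (data, labels, ids) in batch.items():
        worker = workers_by_id[worker_id]
        # flatten the 28x14 half-image and send it to the worker that owns this partition
        data_ptr = data.view(data.shape[0], -1).send(worker)
        # the model segment already lives on that worker, so the result is a remote pointer
        activation_ptr = models[worker_id](data_ptr)
        # move the intermediate activation to the server, where the second head lives
        remote_activations.append(activation_ptr.move(server))

    # concatenate the two 64-dimensional activations into the server head's 128-dimensional input
    server_input = torch.cat(remote_activations, dim=1)
    prediction_ptr = models["server"](server_input)
    print(prediction_ptr.get().shape)  # expected: torch.Size([4, 10])
    break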