Skip to content

Commit

Permalink
Merge pull request #840 from epfml/gdhf
Browse files Browse the repository at this point in the history
GDHF demo
  • Loading branch information
JulienVig authored Dec 3, 2024
2 parents c1e7bb3 + 75a67ec commit 8176df3
Show file tree
Hide file tree
Showing 20 changed files with 210 additions and 96 deletions.
15 changes: 12 additions & 3 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ interface BenchmarkArguments {
roundDuration: number
batchSize: number
save: boolean
host: URL
}

type BenchmarkUnsafeArguments = Omit<BenchmarkArguments, 'provider'> & {
Expand All @@ -22,12 +23,19 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'

const unsafeArgs = parse<BenchmarkUnsafeArguments>(
{
task: { type: String, alias: 't', description: 'Task: titanic, simple_face, cifar10 or lus_covid', defaultValue: 'simple_face' },
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 1 },
task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' },
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
host: {
type: (raw: string) => new URL(raw),
typeLabel: "URL",
description: "Host to connect to",
defaultValue: new URL("http://localhost:8080"),
},

help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' }
},
{
Expand All @@ -42,6 +50,7 @@ const supportedTasks = Map(
defaultTasks.lusCovid,
defaultTasks.simpleFace,
defaultTasks.titanic,
defaultTasks.tinderDog,
).map((t) => [t.getTask().id, t]),
);

Expand Down Expand Up @@ -69,4 +78,4 @@ export const args: BenchmarkArguments = {
},
getModel: () => provider.getModel(),
},
};
};
17 changes: 5 additions & 12 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import type {
TaskProvider,
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'
import { Server } from 'server'

import { getTaskData } from './data.js'
import { args } from './args.js'
Expand Down Expand Up @@ -49,23 +48,17 @@ async function main<D extends DataType>(
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
console.log({ args })

const [server, url] = await new Server().serve(undefined, provider)

const data = await getTaskData(task)

const dataSplits = await Promise.all(
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i))
)
const logs = await Promise.all(
Range(0, numberOfUsers).map(async (_) => await runUser(task, url, data)).toArray()
dataSplits.map(async data => await runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>))
)

if (args.save) {
const fileName = `${task.id}_${numberOfUsers}users.csv`;
await fs.writeFile(fileName, JSON.stringify(logs, null, 2));
}
console.log('Shutting down the server...')
await new Promise((resolve, reject) => {
server.once('close', resolve)
server.close(reject)
})
}

main(args.provider, args.numberOfUsers).catch(console.error)
main(args.provider, args.numberOfUsers).catch(console.error)
39 changes: 32 additions & 7 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import path from "node:path";

import { Dataset, processing } from "@epfml/discojs";
import type {
Dataset,
DataFormat,
DataType,
Image,
Task,
} from "@epfml/discojs";
import { loadCSV, loadImagesInDir } from "@epfml/discojs-node";
import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
import { Repeat } from "immutable";

async function loadSimpleFaceData(): Promise<Dataset<DataFormat.Raw["image"]>> {
Expand Down Expand Up @@ -36,10 +35,34 @@ async function loadLusCovidData(): Promise<Dataset<DataFormat.Raw["image"]>> {
return positive.chain(negative);
}

/**
 * Loads one participant's split of the tinder_dog dataset.
 *
 * Reads `labels.csv` from `../datasets/tinder_dog/<split + 1>/`, where each
 * row holds a `filename` (without extension) and its `label`. For every row
 * the matching image is resolved by trying the png/jpg/jpeg extensions; the
 * first one that loads wins (via `Promise.any`).
 *
 * @param split zero-based participant index; folders on disk are 1-based
 * @returns a lazy Dataset of [image, label] pairs
 * @throws Error when none of the candidate extensions resolves to an image
 */
function loadTinderDogData(split: number): Dataset<DataFormat.Raw["image"]> {
  const folder = path.join("..", "datasets", "tinder_dog", `${split + 1}`);
  return loadCSV(path.join(folder, "labels.csv"))
    .map(
      (row) =>
        [
          processing.extractColumn(row, "filename"),
          processing.extractColumn(row, "label"),
        ] as const,
    )
    .map(async ([filename, label]) => {
      try {
        // Try every supported extension concurrently; Promise.any resolves
        // with the first image that loads and rejects only if all fail.
        const image = await Promise.any(
          ["png", "jpg", "jpeg"].map((ext) =>
            loadImage(path.join(folder, `${filename}.${ext}`)),
          ),
        );
        return [image, label];
      } catch {
        // All extensions failed: surface which row could not be resolved.
        throw Error(`${filename} not found in ${folder}`);
      }
    });
}

export async function getTaskData<D extends DataType>(
task: Task<D>,
taskID: Task<D>['id'], userIdx: number
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (task.id) {
switch (taskID) {
case "simple_face":
return (await loadSimpleFaceData()) as Dataset<DataFormat.Raw[D]>;
case "titanic":
Expand All @@ -52,7 +75,9 @@ export async function getTaskData<D extends DataType>(
).zip(Repeat("cat")) as Dataset<DataFormat.Raw[D]>;
case "lus_covid":
return (await loadLusCovidData()) as Dataset<DataFormat.Raw[D]>;
case "tinder_dog":
return loadTinderDogData(userIdx) as Dataset<DataFormat.Raw[D]>;
default:
throw new Error(`Data loader for ${task.id} not implemented.`);
throw new Error(`Data loader for ${taskID} not implemented.`);
}
}
}
3 changes: 3 additions & 0 deletions datasets/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@

# LUS Covid
/lus_covid/

# GDHF demo
/tinder_dog/
5 changes: 5 additions & 0 deletions datasets/populate
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,8 @@ rm archive.zip DeAI-testimages
mkdir -p wikitext
curl 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz' |
tar --extract --gzip --strip-components=1 -C wikitext

# tinder_dog
curl 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip' > tinder_dog.zip
unzip -u tinder_dog.zip
rm tinder_dog.zip
2 changes: 2 additions & 0 deletions discojs/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{
url.pathname += `tasks/${this.task.id}/model.json`

const response = await fetch(url);
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);

const encoded = new Uint8Array(await response.arrayBuffer())
return await serialization.model.decode(encoded)
}
Expand Down
3 changes: 1 addition & 2 deletions discojs/src/client/federated/federated_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { Client, shortenId } from "../client.js";
import { type, type ClientConnected } from "../messages.js";
import {
waitMessage,
waitMessageWithTimeout,
WebSocketServer,
} from "../event_connection.js";
import * as messages from "./messages.js";
Expand Down Expand Up @@ -75,7 +74,7 @@ export class FederatedClient extends Client {
const {
id, waitForMoreParticipants, payload,
round, nbOfParticipants
} = await waitMessageWithTimeout(this.server, type.NewFederatedNodeInfo);
} = await waitMessage(this.server, type.NewFederatedNodeInfo);

// This should come right after receiving the message to make sure
// we don't miss a subsequent message from the server
Expand Down
1 change: 1 addition & 0 deletions discojs/src/default_tasks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export { mnist } from './mnist.js'
export { simpleFace } from './simple_face.js'
export { titanic } from './titanic.js'
export { wikitext } from './wikitext.js'
export { tinderDog } from './tinder_dog.js'
84 changes: 84 additions & 0 deletions discojs/src/default_tasks/tinder_dog.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import * as tf from '@tensorflow/tfjs'

import type { Model, Task, TaskProvider } from '../index.js'
import { models } from '../index.js'

/**
 * TinderDog task for the GDHF 2024 demo: a binary image classifier that
 * rates dog pictures as cute or less cute, trained federated over 64x64
 * RGB images with a small seeded CNN for reproducible demo runs.
 */
export const tinderDog: TaskProvider<'image'> = {
  getTask (): Task<'image'> {
    return {
      id: 'tinder_dog',
      displayInformation: {
        taskTitle: 'GDHF 2024 | TinderDog',
        summary: {
          preview: 'Which dog is the cutest....or not?',
          overview: "Binary classification model for dog cuteness."
        },
        model: 'The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1',
        dataFormatInformation: 'Accepted image formats are .png .jpg and .jpeg.',
        dataExampleText: '',
        dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog_preview.png',
        sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip',
        sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, pick one of the data splits (the folder 0 for example) and use the CSV option below to select the file named "labels.csv". You can now connect the images located in the same folder.'
      },
      trainingInformation: {
        epochs: 10,
        roundDuration: 2,
        validationSplit: 0, // nicer plot for GDHF demo
        batchSize: 10,
        dataType: 'image',
        IMAGE_H: 64,
        IMAGE_W: 64,
        LABEL_LIST: ['Cute dogs', 'Less cute dogs'],
        scheme: 'federated',
        aggregationStrategy: 'mean',
        minNbOfParticipants: 3,
        tensorBackend: 'tfjs'
      }
    }
  },

  /**
   * Builds the demo CNN: two seeded conv layers, max pooling, dropout,
   * a 32-unit dense layer and a 2-way softmax head.
   *
   * @returns a TFJS-backed Model ready for federated training
   */
  async getModel(): Promise<Model<'image'>> {
    // Fixed seed so every GDHF demo run initializes weights identically.
    const seed = 42
    const { IMAGE_H: height, IMAGE_W: width } = this.getTask().trainingInformation
    const channels = 3

    // Fresh He-normal initializer per layer, all sharing the demo seed.
    const seededInit = () => tf.initializers.heNormal({ seed })

    const net = tf.sequential()

    net.add(tf.layers.conv2d({
      inputShape: [height, width, channels],
      kernelSize: 5,
      filters: 8,
      activation: 'relu',
      kernelInitializer: seededInit()
    }))
    net.add(tf.layers.conv2d({
      kernelSize: 5,
      filters: 16,
      activation: 'relu',
      kernelInitializer: seededInit()
    }))
    net.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }))
    net.add(tf.layers.dropout({ rate: 0.25, seed }))

    net.add(tf.layers.flatten())
    net.add(tf.layers.dense({
      units: 32,
      activation: 'relu',
      kernelInitializer: seededInit()
    }))
    net.add(tf.layers.dropout({ rate: 0.25, seed }))
    net.add(tf.layers.dense({
      units: 2,
      activation: 'softmax',
      kernelInitializer: seededInit()
    }))

    net.compile({
      optimizer: tf.train.adam(0.0005),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy']
    })

    return new models.TFJS('image', net)
  }
}
4 changes: 3 additions & 1 deletion discojs/src/task/task_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,22 @@ export async function pushTask<D extends DataType>(
task: Task<D>,
model: Model<D>,
): Promise<void> {
await fetch(urlToTasks(base), {
const response = await fetch(urlToTasks(base), {
method: "POST",
body: JSON.stringify({
task,
model: await serialization.model.encode(model),
weights: await serialization.weights.encode(model.weights),
}),
});
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);
}

export async function fetchTasks(
base: URL,
): Promise<Map<TaskID, Task<DataType>>> {
const response = await fetch(urlToTasks(base));
if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`);
const tasks: unknown = await response.json();

if (!Array.isArray(tasks)) {
Expand Down
13 changes: 8 additions & 5 deletions discojs/src/training/disco.ts
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,21 @@ export class Disco<D extends DataType> extends EventEmitter<{
): Promise<
[
Dataset<Batched<DataFormat.ModelEncoded[D]>>,
Dataset<Batched<DataFormat.ModelEncoded[D]>>,
Dataset<Batched<DataFormat.ModelEncoded[D]>> | undefined,
]
> {
const { batchSize, validationSplit } = this.#task.trainingInformation;

const preprocessed = await processing.preprocess(this.#task, dataset);
let preprocessed = await processing.preprocess(this.#task, dataset);

const [training, validation] = (
preprocessed = (
this.#preprocessOnce
? new Dataset(await arrayFromAsync(preprocessed))
: preprocessed
).split(validationSplit);
)
if (validationSplit === 0) return [preprocessed.batch(batchSize).cached(), undefined];

const [training, validation] = preprocessed.split(validationSplit);

return [
training.batch(batchSize).cached(),
Expand All @@ -230,4 +233,4 @@ async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
const ret: T[] = [];
for await (const e of iter) ret.push(e);
return ret;
}
}
11 changes: 9 additions & 2 deletions server/src/controllers/federated_controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ export class FederatedController<
*/
#latestGlobalWeights: serialization.Encoded;

constructor(task: Task<D>, initialWeights: serialization.Encoded) {
constructor(task: Task<D>, private readonly initialWeights: serialization.Encoded) {
super(task)
this.#latestGlobalWeights = initialWeights
this.#latestGlobalWeights = this.initialWeights

// Save the latest weight updates to be able to send it to new or outdated clients
this.#aggregator.on('aggregation', async (weightUpdate) => {
Expand Down Expand Up @@ -145,6 +145,13 @@ export class FederatedController<
this.#aggregator.removeNode(clientId)
debug("client [%s] left", shortId)

// Reset the training session when all participants left
if (this.connections.size === 0) {
debug("All participants left. Resetting the training session")
this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative')
this.#latestGlobalWeights = this.initialWeights
}

// Check if we dropped below the minimum number of participant required
// or if we are already waiting for new participants to join
if (this.connections.size >= minNbOfParticipants ||
Expand Down
6 changes: 3 additions & 3 deletions webapp/src/components/containers/ImageCard.vue
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
<template>
<div
class="grid grid-cols-1 w-full bg-white dark:bg-slate-800 aspect-square rounded-xl drop-shadow-md hover:drop-shadow-xl transition duration-500 hover:scale-105 opacity-70 hover:opacity-100 shadow hover:shadow-lg"
class="grid grid-cols-1 w-full bg-white dark:bg-slate-800 aspect-square rounded-xl drop-shadow-md hover:drop-shadow-xl transition duration-500 hover:scale-125 shadow hover:shadow-lg hover:z-50"
>
<div class="grid grid-cols-1 gap-1 text-center content-center p-2 h-16">
<div class="grid grid-cols-1 gap-1 text-center content-center pt-2 h-8">
<slot name="title" />
<div class="text-sm">
<slot name="subtitle" />
Expand Down Expand Up @@ -35,4 +35,4 @@ function draw() {
if (context === null) throw new Error("canvas doesn't support 2D context");
context.putImageData(props.image, 0, 0);
}
</script>
</script>
Loading

0 comments on commit 8176df3

Please sign in to comment.