/*
* @fileoverview Clase para realizar validación de referencia cruzada (Cross Validation)
* utilizando una cantidad n de pliegues (folds).
*
* Generado como parte del proyecto CAIM+BAYES+CV K-FOLDS.
*
* Basado en:
* NBcClasser.php
* Tarea 4. Aprendizaje y clasificación de documentos
* Métodos Probabilísticos para la Inteligencia Artificial
*
* @author Fernando MM
* @version 1.0
* @date 2024-04-11
*
* Dependencias:
* - naiveBayes.js ?? (puede ser cualquier clasificador con los métodos:
* classifier.train(X, y);
* classifier.predict(newX);
*
* Historial de cambios:
* - 1.0 (2024-04-11): Creación de la librería.
*
*/
/**
* @class La clase CrossValidation crea nuevas instancias de aplicaciones de validación cruzada con un clasificador, un conjunto de datos y número de pliegues.
*/
class CrossValidation {
constructor(classifier, data, folds = 10) {
this.classifier = classifier;
this.data = data;
this.folds = Math.min(folds, data.length); // Se asegura que la cantidad de folds no sobrepase la cantidad de datos.
this.bestModel = null;
this.bestAccuracy = 0;
this.averageAccuracy = 0;
}
/**
* @description Divide el conjunto de datos en pliegues.
* @return {object}
*/
splitData() {
const totalItems = this.data.length;
// Garantiza que todos los folds tengan al menos un elemento y que los remanentes se distribuyan uniformemente entre los primeros folds.
const baseFoldSize = Math.floor(totalItems / this.folds);
let remainder = totalItems % this.folds;
// Separa todas las columnas menos la última (X) y los valores de la última columna que tomará como etiquetas (y)
// Ver ejemplo de X, y al final del archivo.
var X = this.data.map(row => row.slice(0, -1));
var y = this.data.map(row => row.slice(-1)[0]);
const X_folds = [];
const y_folds = [];
let start = 0;
//Genera los folds
for (let i = 0; i < this.folds; i++) {
const foldSize = baseFoldSize + (remainder > 0 ? 1 : 0);
const end = start + foldSize;
X_folds.push(X.slice(start, end));
y_folds.push(y.slice(start, end));
start = end;
remainder--; // Disminuye el remanente tras ocupar un valor excedente para este fold
}
console.log('X_folds:');
console.log(X_folds);
console.log('y_folds:');
console.log(y_folds);
return {X_folds, y_folds};
}
/**
* @description Realiza la validación cruzada.
* @return {number}
*/
evaluate() {
const {X_folds, y_folds} = this.splitData();
let totalAccuracy = 0;
for (let i = 0; i < this.folds; i++) {
const X_train = [].concat(...X_folds.filter((_, index) => index !== i));
const y_train = [].concat(...y_folds.filter((_, index) => index !== i));
const X_test = X_folds[i];
const y_test = y_folds[i];
const foldClassifier = new this.classifier.constructor(); // Crear una nueva instancia del clasificador
foldClassifier.train(X_train, y_train);
const predictions = foldClassifier.predict(X_test);
const correct = y_test.filter((label, index) => label === predictions[index]).length;
const accuracy = correct / y_test.length;
totalAccuracy += accuracy;
if (accuracy > this.bestAccuracy) {
this.bestAccuracy = accuracy;
this.bestModel = foldClassifier;
}
}
this.averageAccuracy = totalAccuracy / this.folds;
return this.averageAccuracy;
}
}
/* EJEMPLO JAVASCRIPT
var arffDataSet = new ARFFDataSet();
arffDataSet.readARFF(file, function (attributes, data, labels) {
// La discretización es opcional
let discretizer = new DiscreteARFF(arffDataSet);
let discretizedData = discretizer.discretize();
const kFolds = 10;
const data = discretizedData.data.map((row, i) => [...row, arffDataSet.labels[i]]);
// Puede usarse cualquier otro clasificador con los métodos train(X, y) y predict(newX);
const classifier = new NaiveBayes();
const crossVal = new CrossValidation(classifier, data, kFolds);
console.log(`Resultados de validación CV:`);
console.log(crossVal);
});
*/
/* EJEMPLO X, y
// valores de atributos (todas las columnas menos la última)
const X = [
['red', 'small', 'circle'],
['blue', 'large', 'square'],
['green', 'medium', 'triangle'],
['yellow', 'small', 'circle'],
['red', 'large', 'square'],
['blue', 'medium', 'triangle']
];
// clases (última columna)
const y = ['A', 'B', 'C', 'A', 'B', 'C'];
**/