/*
* @fileoverview Clase para implementar un clasificador de Análisis Discriminante Lineal (LDA)
*
* Generado como parte del proyecto Nican.AI
*
* @author Alexis B.
* @version 1.0
* @date 2024-06-20
*
* Dependencias:
* - mathjs
*
*/
/* Se importa la biblioteca mathjs desde un CDN para realizar cálculos matemáticos avanzados.*/
import * as math from 'https://cdn.jsdelivr.net/npm/mathjs@10.0.0/dist/math.min.js';
/**
* @class La clase LDA encapsula la lógica necesaria para entrenar un modelo LDA, proyectar puntos y realizar predicciones.
*/
class LDA {
// Iicializa la clase LDA con las clases proporcionadas por el ARFF y calcula los parámetros necesarios para cada par de clases.
constructor(...classes) {
if (classes.length < 2) {
throw new Error('Please pass at least 2 classes');
}
let numberOfPairs = classes.length * (classes.length - 1) / 2;
let pair1 = 0;
let pair2 = 1;
let pairs = new Array(numberOfPairs);
for (let i = 0; i < numberOfPairs; i++) {
pairs[i] = this.computeLdaParams(classes[pair1], classes[pair2], pair1, pair2);
pair2++;
if (pair2 === classes.length) {
pair1++;
pair2 = pair1 + 1;
}
}
this.pairs = pairs;
this.numberOfClasses = classes.length;
}
/**
* @returns {String} El nombre del clasificador para identificarlo.
*/
getClassifierName() {
return "LDA";
}
/**
* @description El método train entrena el modelo con datos de entrada X y sus respectivas etiquetas y.
* Además, organiza los datos por clase y calcula los parámetros LDA para cada par de clases.
* @param {Array} X representa las características de las instancias.
* @param {Array} y representa las etiquetas de clase correspondientes.
*/
train(X, y) {
const data = X.map((features, index) => [...features, y[index]]);
const uniqueClasses = [...new Set(y)];
if (uniqueClasses.length < 2) {
throw new Error('Please provide at least two distinct classes.');
}
let classData = uniqueClasses.map(cls => data.filter(row => row[row.length - 1] === cls).map(row => row.slice(0, -1)));
let numberOfPairs = uniqueClasses.length * (uniqueClasses.length - 1) / 2;
let pair1 = 0;
let pair2 = 1;
let pairs = new Array(numberOfPairs);
for (let i = 0; i < numberOfPairs; i++) {
pairs[i] = this.computeLdaParams(classData[pair1], classData[pair2], pair1, pair2);
pair2++;
if (pair2 === uniqueClasses.length) {
pair1++;
pair2 = pair1 + 1;
}
}
this.pairs = pairs;
this.numberOfClasses = uniqueClasses.length;
}
/**
* @description Este método calcula los parémetros theta y b que definen la línea discriminante para separar dos
* clases específicas. La media y la covarianza de cada clase se utilizan para calcular theta, que es el vector
* de pesos, y b, que es el sesgo.
* @param {Number} class1 Valor del par 1.
* @param {Number} class2 Valor del par 2.
* @param {Number} class1id Índice del par 1.
* @param {Number} class2id Índice del par 2.
* @returns Valores de las propiedades theta, b, class1id y class2id.
*
*/
computeLdaParams(class1, class2, class1id, class2id) {
let mu1 = math.transpose(math.mean(class1, 0));
let mu2 = math.transpose(math.mean(class2, 0));
let pooledCov = math.add(this.calculateCovariance(class1), this.calculateCovariance(class2));
const regularizationParam = 1e-5;
const size = math.size(pooledCov);
const identityMatrix = math.identity(size[0]);
pooledCov = math.add(pooledCov, math.multiply(regularizationParam, identityMatrix));
let theta = math.multiply(math.inv(pooledCov), math.subtract(mu2, mu1));
let b = math.multiply(-1, math.transpose(theta), math.add(mu1, mu2), 1 / 2);
return {
theta: theta,
b: b,
class1id: class1id,
class2id: class2id
};
}
/**
* @description Este método calcula la matriz de covarianza de un conjunto de datos.
* @param {Array} data Arreglo del conjunto de datos X.
* @returns Matriz de covarianza.
*/
calculateCovariance(data) {
const mean = math.mean(data, 0);
const diff = data.map(row => math.subtract(row, mean));
const covMatrix = math.multiply(math.transpose(diff), diff);
return math.divide(covMatrix, data.length - 1);
}
/**
* @description El método project proyecta un punto en el espacio de la línea discriminante definida por theta
* y b. Actualmente, solo soporta la proyección para dos clases.
* @param {Number} point Punto a proyectar.
* @returns Propiedad proyección
*/
project(point) {
if (this.pairs.length !== 1) {
throw new Error('LDA project currently only supports 2 classes.');
}
return this.projectPoint(point, this.pairs[0].theta, this.pairs[0].b);
}
/**
* @description Este método calcula la proyección de un punto dado un vector theta y un sesgo b.
* @param {Number} point Punto a proyectar.
* @param {Array} theta Vector theta.
* @param {Number} b Sesgo.
* @returns Arreglo de proyección.
*/
projectPoint(point, theta, b) {
return math.add(math.multiply(point, theta), b);
}
/**
* @description El método predict se utiliza para predecir la clase de nuevos datos. En el caso de dos clases,
* se usa una única línea discriminante, mientras que para múltiples clases se utilizan todas las líneas
* discriminantes calculadas.
* @param {Array} X representa las características de las instancias.
* @returns Predicción realizada de las instancias.
*/
predict(X) {
return X.map(instance => this.predictInstance(instance));
}
/**
* @description El método predictInstance realiza las operaciones necesarias para generar la predicción de cada instancia.
* @param {Array} instance Instancia que estará siendo evaluada.
* @returns Instancia predicha.
*/
predictInstance(instance) {
if (this.numberOfClasses === 2) {
return this.projectPoint(instance, this.pairs[0].theta, this.pairs[0].b) <= 0 ? 0 : 1;
}
let votes = new Array(this.numberOfClasses).fill(0);
for (let i = 0; i < this.pairs.length; i++) {
let params = this.pairs[i];
let projection = this.projectPoint(instance, params.theta, params.b);
if (projection <= 0) {
votes[params.class1id]++;
} else {
votes[params.class2id]++;
}
}
return votes.indexOf(Math.max(...votes));
}
}
export default LDA;
/*
// Ejemplo de uso
const lda = new LDA();
const X = [
[1, 2],
[2, 3],
[3, 4],
[4, 5]
];
const y = [0, 0, 1, 1];
lda.train(X, y);
const newInstances = [
[1, 3],
[3, 5]
];
const predictions = lda.predict(newInstances);
console.log("Predicciones:", predictions);
*/