Quantcast
Channel: Machine Learning
Viewing all articles
Browse latest Browse all 62787

Can't figure out how to properly implement simple gradient descent

$
0
0

Hi r/machinelearning, I'm in the process of learning Machine Learning for my own personal curiosity. I'm loosely using Andrew Ng's lectures and notes from CS229, but am not taking the course.

I found a data set online with various car properties including miles per gallon. I'm trying to use the car properties to predict miles per gallon. This is the equation I'm using (or at least attempting to use):

equation: http://i.imgur.com/qVRfN.png

source: http://cs229.stanford.edu/notes/cs229-notes1.pdf

What I'm finding is that the error does go down when this equation is used. Also, I end up with negative weights for "Acceleration" and "Weight", which makes sense (higher acceleration and weight lead to lower MPG, that checks out). However, once I actually apply the weights to create predictions, the predictions are entirely off.

Is this a common problem for beginners? Is there any clear indication of what I might be doing incorrectly? Here is my code; any tips would be greatly appreciated. Also, if you have any comments on my code, please chime in. I am aware that C# is not the best language for this and that I should be using matrices; other than that, I am interested in learning more from other people.

--START CODE--

using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;

namespace CarMpg
{
    /// <summary>
    /// Entry point: loads the car data, fits a linear model with gradient
    /// descent and writes "actual,predicted" MPG pairs to a CSV file.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            var cars = Importer.GetCars();

            // NOTE(review): the features are fed in unscaled, which is presumably why
            // the learning rate has to be this small to keep the updates from diverging.
            // With a step this tiny, 1000 iterations may leave theta close to its 0.5
            // starting values - consider normalising the features. TODO confirm.
            var theta = Model.GetTheta(cars, .00000000001);
            var predictions = Model.GetPredictions(cars, theta);

            var outputPath = @"H:\_MachineLearning\CarMpg\Result.csv";
            if (File.Exists(outputPath))
            {
                File.Delete(outputPath);
            }

            // 'using' guarantees the file is flushed and closed even if a write throws.
            using (var streamWriter = new StreamWriter(outputPath))
            {
                foreach (var prediction in predictions)
                {
                    // Invariant culture keeps '.' as the decimal separator so that it
                    // can never collide with the ',' column separator.
                    streamWriter.WriteLine(
                        prediction.Car.Mpg.ToString(CultureInfo.InvariantCulture)
                        + ","
                        + prediction.MpgPrediction.ToString(CultureInfo.InvariantCulture));
                }
            }
        }
    }

    /// <summary>
    /// Linear-regression model: hypothesis evaluation, training by batch
    /// gradient descent, and prediction.
    /// </summary>
    public static class Model
    {
        /// <summary>Number of gradient-descent passes made over the whole data set.</summary>
        private const int Iterations = 1000;

        /// <summary>
        /// Evaluates the linear hypothesis for one car: the baseline weight plus
        /// the sum of each feature multiplied by its weight.
        /// </summary>
        /// <param name="car">Car whose features are fed into the model.</param>
        /// <param name="theta">Current model weights.</param>
        /// <returns>The predicted miles per gallon.</returns>
        public static double GetH(Car car, Theta theta)
        {
            return theta.BaselineWeight
                + theta.AccelerationWeight * car.Acceleration
                + theta.CylindersWeight * car.Cylinders
                + theta.DisplacementWeight * car.Displacement
                + theta.HorsePowerWeight * car.HorsePower
                + theta.ModelYearWeight * car.ModelYear
                + theta.WeightWeight * car.Weight;
        }

        /// <summary>
        /// Fits the weights with batch gradient descent: on every iteration the
        /// per-car error terms are summed over the whole data set and each weight
        /// is moved by <paramref name="alpha"/> times its summed term.
        /// </summary>
        /// <param name="cars">Training examples.</param>
        /// <param name="alpha">Learning rate (step size).</param>
        /// <returns>The weights after the final iteration.</returns>
        public static Theta GetTheta(List<Car> cars, double alpha)
        {
            var theta = new Theta();

            // Sum-of-squares cost per iteration. Nothing in this file reads it;
            // presumably it is kept so convergence can be inspected in a debugger.
            var costs = new List<double>();

            for (int i = 0; i < Iterations; i++)
            {
                // Per-example gradient contributions, all computed from the same
                // (not yet updated) theta so the update below is simultaneous.
                var errors = new List<Error>();
                foreach (var car in cars)
                {
                    var errorFactor = car.Mpg - GetH(car, theta);
                    errors.Add(new Error(car, errorFactor));
                }

                costs.Add(cars.Select(a => Math.Pow(a.Mpg - GetH(a, theta), 2)).Sum());

                var errorSum = new Error();
                errorSum.BaselineWeight = errors.Sum(a => a.BaselineWeight);
                errorSum.AccelerationWeight = errors.Sum(a => a.AccelerationWeight);
                errorSum.CylindersWeight = errors.Sum(a => a.CylindersWeight);
                errorSum.DisplacementWeight = errors.Sum(a => a.DisplacementWeight);
                errorSum.HorsePowerWeight = errors.Sum(a => a.HorsePowerWeight);
                errorSum.ModelYearWeight = errors.Sum(a => a.ModelYearWeight);
                errorSum.WeightWeight = errors.Sum(a => a.WeightWeight);

                // The error is (actual - predicted), so adding alpha * sum moves
                // each weight in the direction that reduces the squared error.
                theta.BaselineWeight += errorSum.BaselineWeight * alpha;
                theta.AccelerationWeight += errorSum.AccelerationWeight * alpha;
                theta.CylindersWeight += errorSum.CylindersWeight * alpha;
                theta.DisplacementWeight += errorSum.DisplacementWeight * alpha;
                theta.HorsePowerWeight += errorSum.HorsePowerWeight * alpha;
                theta.ModelYearWeight += errorSum.ModelYearWeight * alpha;
                theta.WeightWeight += errorSum.WeightWeight * alpha;
            }

            return theta;
        }

        /// <summary>Builds one <see cref="Prediction"/> per car using the given weights.</summary>
        /// <param name="cars">Cars to predict for.</param>
        /// <param name="theta">Trained model weights.</param>
        /// <returns>The predictions, in the same order as <paramref name="cars"/>.</returns>
        public static List<Prediction> GetPredictions(List<Car> cars, Theta theta)
        {
            return cars.Select(a => new Prediction(a, theta)).ToList();
        }
    }

    /// <summary>Reads the car data set from disk.</summary>
    public static class Importer
    {
        /// <summary>
        /// Loads every usable row of the data file. The first line is treated as a
        /// header; rows containing '?' and blank rows are skipped.
        /// </summary>
        /// <returns>The parsed cars, in file order.</returns>
        public static List<Car> GetCars()
        {
            var path = @"H:\_MachineLearning\CarMpg\Data.csv";
            var cars = new List<Car>();

            // 'using' closes the file handle, which the reader previously leaked.
            using (var streamReader = new StreamReader(path))
            {
                streamReader.ReadLine(); // Skip header

                var line = streamReader.ReadLine();
                while (line != null)
                {
                    // Blank lines (e.g. a trailing newline) would otherwise make the
                    // Car constructor fail with an index-out-of-range error.
                    if (!string.IsNullOrWhiteSpace(line) && !line.Contains("?"))
                    {
                        cars.Add(new Car(line));
                    }

                    line = streamReader.ReadLine();
                }
            }

            return cars;
        }
    }

    /// <summary>One row of the data set: the MPG target plus the car's properties.</summary>
    public class Car
    {
        public double Mpg;
        public double Cylinders;
        public double Displacement;
        public double HorsePower;
        public double Weight;
        public double Acceleration;
        public double ModelYear;
        public double Origin;
        public string Name;

        /// <summary>
        /// Parses a comma-separated row with the columns, in order: mpg, cylinders,
        /// displacement, horsepower, weight, acceleration, model year, origin, name.
        /// </summary>
        /// <param name="line">A single data line from the CSV file.</param>
        public Car(string line)
        {
            var splits = line.Split(',');

            // Parse with the invariant culture so "18.5" means the same thing on
            // machines whose regional settings use ',' as the decimal separator.
            var culture = CultureInfo.InvariantCulture;
            Mpg = Convert.ToDouble(splits[0], culture);
            Cylinders = Convert.ToDouble(splits[1], culture);
            Displacement = Convert.ToDouble(splits[2], culture);
            HorsePower = Convert.ToDouble(splits[3], culture);
            Weight = Convert.ToDouble(splits[4], culture);
            Acceleration = Convert.ToDouble(splits[5], culture);
            ModelYear = Convert.ToDouble(splits[6], culture);
            Origin = Convert.ToDouble(splits[7], culture);
            Name = splits[8].Replace("\"", "");
        }
    }

    /// <summary>The model weights; one per feature plus the baseline (intercept).</summary>
    public class Theta : Row
    {
    }

    /// <summary>
    /// Per-weight error terms: either one car's contribution, or (via the
    /// parameterless constructor) a holder for terms summed over all cars.
    /// </summary>
    public class Error : Row
    {
        public Error()
        {
        }

        /// <summary>
        /// Computes one car's contribution for every weight: the prediction error
        /// multiplied by the matching feature (the baseline's feature is 1).
        /// </summary>
        /// <param name="car">The training example.</param>
        /// <param name="errorFactor">Actual MPG minus predicted MPG for this car.</param>
        public Error(Car car, double errorFactor)
        {
            BaselineWeight = errorFactor;
            CylindersWeight = car.Cylinders * errorFactor;
            DisplacementWeight = car.Displacement * errorFactor;
            HorsePowerWeight = car.HorsePower * errorFactor;
            WeightWeight = car.Weight * errorFactor;
            // Fixed: this previously multiplied by car.Weight, so the acceleration
            // weight was being trained on the wrong feature.
            AccelerationWeight = car.Acceleration * errorFactor;
            ModelYearWeight = car.ModelYear * errorFactor;
        }
    }

    /// <summary>A set of seven per-feature values, shared by <see cref="Theta"/> and <see cref="Error"/>.</summary>
    public class Row
    {
        /// <summary>Value every field starts with.</summary>
        private const double InitialWeight = .5;

        public double BaselineWeight;
        public double CylindersWeight;
        public double DisplacementWeight;
        public double HorsePowerWeight;
        public double WeightWeight;
        public double AccelerationWeight;
        public double ModelYearWeight;

        /// <summary>Initialises every value to 0.5.</summary>
        public Row()
        {
            BaselineWeight = InitialWeight;
            CylindersWeight = InitialWeight;
            DisplacementWeight = InitialWeight;
            HorsePowerWeight = InitialWeight;
            WeightWeight = InitialWeight;
            AccelerationWeight = InitialWeight;
            ModelYearWeight = InitialWeight;
        }
    }

    /// <summary>Pairs a car with the MPG the model predicts for it.</summary>
    public class Prediction
    {
        public Car Car;
        public double MpgPrediction;

        /// <summary>Evaluates the model for <paramref name="car"/> and stores the result.</summary>
        /// <param name="car">The car to predict for.</param>
        /// <param name="theta">Trained model weights.</param>
        public Prediction(Car car, Theta theta)
        {
            Car = car;
            MpgPrediction = Model.GetH(car, theta);
        }
    }
}
submitted by leex1867
[link][16 comments]

Viewing all articles
Browse latest Browse all 62787

Trending Articles