Line data Source code
1 : #include "../../includes/inference/LaneCurveFitter.hpp"
2 : #include <numeric>
3 : #include <cmath>
4 : #include <map>
5 : #include <set>
6 :
7 23 : LaneCurveFitter::LaneCurveFitter(float eps, int minSamples, int windows, int laneWidthPx)
8 23 : : dbscanEps(eps), dbscanMinSamples(minSamples), numWindows(windows), laneWidthPx(laneWidthPx) {}
9 :
10 0 : std::vector<cv::Point> LaneCurveFitter::extractLanePoints(const cv::Mat& binaryMask) {
11 0 : std::vector<cv::Point> points;
12 0 : for (int y = 0; y < binaryMask.rows; ++y) {
13 0 : for (int x = 0; x < binaryMask.cols; ++x) {
14 0 : if (binaryMask.at<uchar>(y, x) > 0)
15 0 : points.emplace_back(x, y);
16 : }
17 : }
18 0 : return points;
19 : }
20 :
21 0 : float interpolateXatY(const std::vector<cv::Point2f>& points, float y_query) {
22 0 : if (points.empty()) return 0.0f;
23 :
24 0 : for (size_t i = 1; i < points.size(); ++i) {
25 0 : float y1 = points[i - 1].y;
26 0 : float y2 = points[i].y;
27 :
28 0 : if ((y1 <= y_query && y_query <= y2) || (y2 <= y_query && y_query <= y1)) {
29 0 : float t = (y_query - y1) / (y2 - y1 + 1e-6f);
30 0 : float x1 = points[i - 1].x;
31 0 : float x2 = points[i].x;
32 0 : return x1 + t * (x2 - x1);
33 : }
34 : }
35 :
36 : // Extrapolate if y_query is outside range
37 0 : return points.back().x;
38 : }
39 :
40 : // Simple DBSCAN implementation (brute force)
41 0 : std::vector<int> LaneCurveFitter::dbscanCluster(const std::vector<cv::Point>& points, std::vector<int>& uniqueLabels) {
42 0 : const int n = points.size();
43 0 : std::vector<int> labels(n, -1);
44 0 : int clusterId = 0;
45 :
46 0 : for (int i = 0; i < n; ++i) {
47 0 : if (labels[i] != -1) continue;
48 :
49 0 : std::vector<int> neighbors;
50 0 : for (int j = 0; j < n; ++j) {
51 0 : if (cv::norm(points[i] - points[j]) <= dbscanEps)
52 0 : neighbors.push_back(j);
53 : }
54 :
55 0 : if (neighbors.size() < dbscanMinSamples)
56 0 : continue;
57 :
58 0 : labels[i] = clusterId;
59 0 : std::set<int> seeds(neighbors.begin(), neighbors.end());
60 0 : seeds.erase(i);
61 :
62 0 : while (!seeds.empty()) {
63 0 : int current = *seeds.begin();
64 0 : seeds.erase(seeds.begin());
65 :
66 0 : if (labels[current] == -1) {
67 0 : labels[current] = clusterId;
68 :
69 0 : std::vector<int> currentNeighbors;
70 0 : for (int j = 0; j < n; ++j) {
71 0 : if (cv::norm(points[current] - points[j]) <= dbscanEps)
72 0 : currentNeighbors.push_back(j);
73 : }
74 :
75 0 : if (currentNeighbors.size() >= dbscanMinSamples) {
76 0 : seeds.insert(currentNeighbors.begin(), currentNeighbors.end());
77 : }
78 : }
79 : }
80 :
81 0 : ++clusterId;
82 : }
83 :
84 0 : uniqueLabels.clear();
85 0 : for (int l : labels)
86 0 : if (l != -1)
87 0 : uniqueLabels.push_back(l);
88 0 : std::sort(uniqueLabels.begin(), uniqueLabels.end());
89 0 : uniqueLabels.erase(std::unique(uniqueLabels.begin(), uniqueLabels.end()), uniqueLabels.end());
90 :
91 0 : return labels;
92 : }
93 :
94 0 : std::pair<std::vector<float>, std::vector<float>> LaneCurveFitter::slidingWindowCentroids(const std::vector<cv::Point>& cluster, cv::Size imgSize, bool smooth) {
95 0 : std::vector<float> cx, cy;
96 0 : int h = imgSize.height / numWindows;
97 :
98 0 : for (int i = 0; i < numWindows; ++i) {
99 0 : int yLow = imgSize.height - (i + 1) * h;
100 0 : int yHigh = imgSize.height - i * h;
101 :
102 0 : std::vector<float> xAcc, yAcc;
103 0 : for (const auto& pt : cluster) {
104 0 : if (pt.y >= yLow && pt.y < yHigh) {
105 0 : xAcc.push_back(pt.x);
106 0 : yAcc.push_back(pt.y);
107 : }
108 : }
109 :
110 0 : if (!xAcc.empty()) {
111 0 : cx.push_back(std::accumulate(xAcc.begin(), xAcc.end(), 0.0f) / xAcc.size());
112 0 : cy.push_back(std::accumulate(yAcc.begin(), yAcc.end(), 0.0f) / yAcc.size());
113 : }
114 : }
115 :
116 0 : if (smooth && cx.size() >= 3) {
117 0 : for (size_t i = 1; i + 1 < cx.size(); ++i) {
118 0 : cx[i] = (cx[i - 1] + cx[i] + cx[i + 1]) / 3.0f;
119 : }
120 : }
121 :
122 0 : return {cy, cx};
123 : }
124 :
125 0 : bool LaneCurveFitter::isStraightLine(const std::vector<float>& y, const std::vector<float>& x, float threshold) {
126 0 : if (x.size() < 4) return false;
127 :
128 0 : float mean_x = std::accumulate(x.begin(), x.end(), 0.0f) / x.size();
129 0 : float mean_y = std::accumulate(y.begin(), y.end(), 0.0f) / y.size();
130 :
131 0 : float num = 0.0f, den_x = 0.0f, den_y = 0.0f;
132 0 : for (size_t i = 0; i < x.size(); ++i) {
133 0 : num += (x[i] - mean_x) * (y[i] - mean_y);
134 0 : den_x += (x[i] - mean_x) * (x[i] - mean_x);
135 0 : den_y += (y[i] - mean_y) * (y[i] - mean_y);
136 : }
137 :
138 0 : float corr = num / std::sqrt(den_x * den_y + 1e-6f);
139 0 : return std::abs(corr) > threshold;
140 : }
141 :
142 0 : bool LaneCurveFitter::hasSignFlip(const std::vector<float>& xVals) {
143 0 : std::vector<float> dx2(xVals.size());
144 0 : for (size_t i = 1; i + 1 < xVals.size(); ++i)
145 0 : dx2[i] = xVals[i + 1] + xVals[i - 1] - 2 * xVals[i];
146 :
147 0 : for (size_t i = 1; i < dx2.size(); ++i)
148 0 : if ((dx2[i] > 0) != (dx2[i - 1] > 0))
149 0 : return true;
150 0 : return false;
151 : }
152 :
153 0 : std::vector<float> LaneCurveFitter::fitCurve(const std::vector<float>& y, const std::vector<float>& x, const std::vector<float>& yEval) {
154 0 : if (y.size() < 3 || x.size() < 3) {
155 : // Fallback: return a straight horizontal line
156 0 : std::vector<float> fallback(yEval.size(), x.empty() ? 0.0f : x[0]);
157 0 : return fallback;
158 : }
159 :
160 0 : cv::Mat A(y.size(), 3, CV_32F);
161 0 : cv::Mat X(x);
162 :
163 0 : for (size_t i = 0; i < y.size(); ++i) {
164 0 : A.at<float>(i, 0) = y[i] * y[i];
165 0 : A.at<float>(i, 1) = y[i];
166 0 : A.at<float>(i, 2) = 1.0f;
167 : }
168 :
169 0 : cv::Mat coeffs;
170 0 : if (!cv::solve(A, X, coeffs, cv::DECOMP_SVD)) {
171 : // Fallback in case of failure
172 0 : std::vector<float> fallback(yEval.size(), x[0]);
173 0 : return fallback;
174 : }
175 :
176 0 : std::vector<float> result;
177 0 : for (float yv : yEval) {
178 0 : result.push_back(coeffs.at<float>(0) * yv * yv + coeffs.at<float>(1) * yv + coeffs.at<float>(2));
179 : }
180 0 : return result;
181 : }
182 :
183 :
184 0 : std::vector<LaneCurveFitter::LaneCurve> LaneCurveFitter::fitLanes(const cv::Mat& binaryMask) {
185 0 : std::vector<LaneCurve> lanes;
186 0 : auto points = extractLanePoints(binaryMask);
187 :
188 0 : std::vector<int> uniqueLabels;
189 0 : auto labels = dbscanCluster(points, uniqueLabels);
190 :
191 0 : for (int label : uniqueLabels) {
192 0 : std::vector<cv::Point> cluster;
193 0 : for (size_t i = 0; i < labels.size(); ++i)
194 0 : if (labels[i] == label)
195 0 : cluster.push_back(points[i]);
196 :
197 0 : auto [cy, cx] = slidingWindowCentroids(cluster, binaryMask.size(), false);
198 0 : if (cy.size() < 2) continue;
199 :
200 0 : std::vector<size_t> sortIdx(cy.size());
201 0 : std::iota(sortIdx.begin(), sortIdx.end(), 0);
202 0 : std::sort(sortIdx.begin(), sortIdx.end(), [&](size_t i, size_t j) { return cy[i] < cy[j]; });
203 :
204 0 : std::vector<float> y_sorted, x_sorted;
205 0 : for (auto i : sortIdx) {
206 0 : y_sorted.push_back(cy[i]);
207 0 : x_sorted.push_back(cx[i]);
208 : }
209 :
210 0 : auto testCurve = fitCurve(y_sorted, x_sorted, y_sorted);
211 0 : if (hasSignFlip(testCurve)) {
212 0 : std::tie(cy, cx) = slidingWindowCentroids(cluster, binaryMask.size(), true);
213 0 : sortIdx = std::vector<size_t>(cy.size());
214 0 : std::iota(sortIdx.begin(), sortIdx.end(), 0);
215 0 : std::sort(sortIdx.begin(), sortIdx.end(), [&](size_t i, size_t j) { return cy[i] < cy[j]; });
216 :
217 0 : y_sorted.clear(); x_sorted.clear();
218 0 : for (auto i : sortIdx) {
219 0 : y_sorted.push_back(cy[i]);
220 0 : x_sorted.push_back(cx[i]);
221 : }
222 : }
223 :
224 0 : float y_min = *std::min_element(y_sorted.begin(), y_sorted.end());
225 0 : float y_max = *std::max_element(y_sorted.begin(), y_sorted.end());
226 0 : std::vector<float> y_plot(300);
227 0 : float step = (y_max + 10 - (y_min - 30)) / 300.0f;
228 0 : for (int i = 0; i < 300; ++i)
229 0 : y_plot[i] = y_max + 10 - i * step;
230 :
231 0 : std::vector<float> x_plot = fitCurve(y_sorted, x_sorted, y_plot);
232 :
233 0 : std::vector<cv::Point2f> curve, cents;
234 0 : for (size_t i = 0; i < y_plot.size(); ++i)
235 0 : curve.emplace_back(x_plot[i], y_plot[i]);
236 0 : for (size_t i = 0; i < x_sorted.size(); ++i)
237 0 : cents.emplace_back(x_sorted[i], y_sorted[i]);
238 :
239 0 : lanes.push_back({cents, curve});
240 : }
241 :
242 0 : return lanes;
243 : }
244 :
245 0 : std::optional<LaneCurveFitter::CenterlineResult> LaneCurveFitter::computeVirtualCenterline(const std::vector<LaneCurve>& lanes, int imgWidth, int imgHeight) {
246 0 : const float centerX = imgWidth / 2.0f;
247 0 : LaneCurve left, right;
248 0 : std::vector<std::pair<float, LaneCurve>> candidates;
249 :
250 0 : for (const auto& lane : lanes) {
251 0 : std::vector<float> bottomXs;
252 0 : for (const auto& pt : lane.curve)
253 0 : if (pt.y >= imgHeight / 2) bottomXs.push_back(pt.x);
254 0 : if (bottomXs.empty()) continue;
255 :
256 0 : float avgX = std::accumulate(bottomXs.begin(), bottomXs.end(), 0.0f) / bottomXs.size();
257 0 : candidates.emplace_back(avgX, lane);
258 : }
259 :
260 0 : std::sort(candidates.begin(), candidates.end(), [](auto& a, auto& b) { return a.first < b.first; });
261 :
262 0 : for (const auto& [avgX, lane] : candidates) {
263 0 : if (avgX < centerX)
264 0 : left = lane;
265 0 : else if (!right.curve.size())
266 0 : right = lane;
267 : }
268 :
269 0 : std::vector<cv::Point2f> c1, c2, blended;
270 0 : if (!left.curve.empty() && !right.curve.empty()) {
271 0 : std::vector<float> y_common(300);
272 0 : float y_start = imgHeight - 1;
273 0 : float y_end = std::max(left.curve.back().y, right.curve.back().y);
274 0 : float dy = (y_start - y_end) / 299.0f;
275 0 : for (int i = 0; i < 300; ++i)
276 0 : y_common[i] = y_start - i * dy;
277 :
278 0 : std::vector<float> xl(300), xr(300);
279 0 : for (int i = 0; i < 300; ++i) {
280 0 : xl[i] = interpolateXatY(left.curve, y_common[i]);
281 0 : xr[i] = interpolateXatY(right.curve, y_common[i]);
282 : }
283 :
284 0 : for (int i = 0; i < 300; ++i) {
285 0 : float mid = (xl[i] + xr[i]) / 2.0f;
286 0 : float w = static_cast<float>(i) / 299.0f;
287 0 : float blendX = w * mid + (1 - w) * centerX;
288 :
289 0 : c1.emplace_back(mid, y_common[i]);
290 0 : c2.emplace_back(centerX, y_common[i]);
291 0 : blended.emplace_back(blendX, y_common[i]);
292 : }
293 0 : return CenterlineResult{blended, c1, c2};
294 : }
295 :
296 0 : return std::nullopt;
297 : }
|